From b51cfb66170705f254b6075c578865db5551e19f Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Tue, 5 Nov 2024 00:58:32 -0500 Subject: [PATCH] while loop and multi assignment --- luisa_lang/codegen/cpp.py | 23 ++- luisa_lang/hir.py | 16 ++- luisa_lang/lang.py | 6 +- luisa_lang/lang_builtins.py | 5 +- luisa_lang/parse.py | 278 +++++++++++++++++++++++++++--------- 5 files changed, 254 insertions(+), 74 deletions(-) diff --git a/luisa_lang/codegen/cpp.py b/luisa_lang/codegen/cpp.py index 20bb2a2..852989c 100644 --- a/luisa_lang/codegen/cpp.py +++ b/luisa_lang/codegen/cpp.py @@ -44,6 +44,16 @@ def gen_impl(self, ty: hir.Type) -> str: return name case hir.UnitType(): return 'void' + case hir.TupleType(): + def do(): + elements = [self.gen(e) for e in ty.elements] + name = f'Tuple_{unique_hash("".join(elements))}' + self.impl.writeln(f'struct {name} {{') + for i, element in enumerate(elements): + self.impl.writeln(f' {element} _{i};') + self.impl.writeln('};') + return name + return do() case _: raise NotImplementedError(f"unsupported type: {ty}") @@ -129,6 +139,9 @@ def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str: return f"__builtin_{name}" case hir.StructType(name=name): return name + case hir.TupleType(): + elements = [self.mangle(e) for e in obj.elements] + return f"T{unique_hash(''.join(elements))}" case _: raise NotImplementedError(f"unsupported object: {obj}") @@ -275,6 +288,10 @@ def impl() -> None: else: raise NotImplementedError( f"unsupported constant: {constant}") + case hir.AggregateInit(): + assert expr.type + ty = self.base.type_cache.gen(expr.type) + self.body.writeln(f"{ty} v{vid}{{}};") case _: raise NotImplementedError( f"unsupported expression: {expr}") @@ -310,12 +327,12 @@ def gen_node(self, node: hir.Node): vid = self.new_vid() self.body.write(f"auto loop{vid}_prepare = [&]()->bool {{") self.body.indent += 1 - self.gen_bb(loop.prepare) + self.gen_bb(loop.prepare) if loop.cond: self.body.writeln(f"return {self.gen_expr(loop.cond)};") else: self.body.writeln("return true;") - self.body.indent -=1 + self.body.indent -= 1 self.body.writeln("};") self.body.writeln(f"auto loop{vid}_body = [&]() {{") self.body.indent += 1 @@ -354,7 +371,7 @@ def gen_locals(self): continue assert ( local.type - ), f"Local variable {local.name} contains unresolved type, please resolve it via TypeInferencer" + ), f"Local variable `{local.name}` contains unresolved type" self.body.writeln( f"{self.base.type_cache.gen(local.type)} {local.name}{{}};" ) diff --git a/luisa_lang/hir.py b/luisa_lang/hir.py index 540ec34..6a5d575 100644 --- a/luisa_lang/hir.py +++ b/luisa_lang/hir.py @@ -12,7 +12,6 @@ Tuple, Dict, Union, - cast, ) import typing from typing_extensions import override @@ -128,6 +127,9 @@ def method(self, name: str) -> Optional[FunctionLike | FunctionTemplate]: def is_concrete(self) -> bool: return True + + def __len__(self) -> int: + return 1 class UnitType(Type): @@ -337,7 +339,9 @@ def member(self, field: Any) -> Optional['Type']: return self.element return Type.member(self, field) - + def __len__(self) -> int: + return self.count + class ArrayType(Type): element: Type count: Union[int, "SymbolicConstant"] @@ -789,7 +793,7 @@ class Index(Value): index: Value def __init__(self, base: Value, index: Value, type: Type, span: Optional[Span]) -> None: - super().__init__(None, span) + super().__init__(type, span) self.base = base self.index = index @@ -857,6 +861,12 @@ def __init__(self, ty: Type, span: Optional[Span] = None) -> None: # super().__init__(ty, span) # self.init_call = init_call +class AggregateInit(Value): + args: List[Value] + + def __init__(self, args: List[Value], type: Type, span: Optional[Span] = None) -> None: + super().__init__(type, span) + self.args = args class Call(Value): op: FunctionLike diff --git a/luisa_lang/lang.py b/luisa_lang/lang.py index 9069a7d..6af2d6a 100644 --- a/luisa_lang/lang.py +++ b/luisa_lang/lang.py @@ -1,6 +1,7 @@ from luisa_lang.classinfo import VarType, GenericInstance, UnionType, _get_cls_globalns, register_class, class_typeinfo from enum import Enum, auto from typing_extensions import TypeAliasType +import typing from typing import ( Callable, Dict, @@ -14,7 +15,6 @@ Union, Generic, Literal, - cast, overload, Any, ) @@ -109,7 +109,7 @@ def _dsl_func_impl(f: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T: template = _make_func_template(f, func_name, func_globals) ctx.functions[f] = template setattr(f, "__luisa_func__", template) - return cast(_T, f) + return typing.cast(_T, f) else: raise NotImplementedError() # return cast(_T, f) @@ -150,7 +150,7 @@ def get_ir_type(var_ty: VarType) -> hir.Type: def _dsl_decorator_impl(obj: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T: if kind == _ObjKind.STRUCT: assert isinstance(obj, type), f"{obj} is not a type" - return cast(_T, _dsl_struct_impl(obj, attrs)) + return typing.cast(_T, _dsl_struct_impl(obj, attrs)) elif kind == _ObjKind.FUNC or kind == _ObjKind.KERNEL: return _dsl_func_impl(obj, kind, attrs) raise NotImplementedError() diff --git a/luisa_lang/lang_builtins.py b/luisa_lang/lang_builtins.py index 2e6e02f..eab0ec3 100644 --- a/luisa_lang/lang_builtins.py +++ b/luisa_lang/lang_builtins.py @@ -23,7 +23,7 @@ def block_id() -> uint3: @_builtin -def convert(target: type[_T], value: Any) -> _T: +def cast(target: type[_T], value: Any) -> _T: """ Attempt to convert the value to the target type. """ @@ -185,4 +185,7 @@ def value(self, value: _T) -> None: 'static_assert', 'type_of_opt', 'typeof', + "dispatch_id", + "thread_id", + "block_id", ] diff --git a/luisa_lang/parse.py b/luisa_lang/parse.py index 1a725f5..34be529 100644 --- a/luisa_lang/parse.py +++ b/luisa_lang/parse.py @@ -5,9 +5,10 @@ import typing import luisa_lang from luisa_lang.lang_builtins import comptime -from luisa_lang.utils import get_typevar_constrains_and_bounds +from luisa_lang.utils import get_typevar_constrains_and_bounds, unwrap import luisa_lang.hir as hir import sys +from copy import copy from luisa_lang.utils import retrieve_ast_and_filename from luisa_lang.hir import ( Type, @@ -48,7 +49,7 @@ class TypeParser: implicit_type_params: Dict[str, hir.Type] def __init__(self, ctx_name: str, globalns: Dict[str, Any], type_var_ns: Dict[typing.TypeVar, hir.Type | ComptimeValue], self_type: Optional[Type] = None) -> None: - self.globalns = globalns + self.globalns = copy(globalns) self.self_type = self_type self.type_var_ns = type_var_ns self.ctx_name = ctx_name @@ -96,7 +97,8 @@ def convert_func_signature(signature: classinfo.MethodType, mode: Literal['parse', 'instantiate'] = 'parse' ) -> Tuple[hir.FunctionSignature, TypeParser]: """ - implicit_type_params: Tuple[List[Tuple[str, classinfo.VarType]], classinfo.VarType] + implicit_type_params: Tuple[List[Tuple[str, + classinfo.VarType]], classinfo.VarType] """ type_parser = TypeParser(ctx_name, globalns, type_var_ns, self_type) type_parser.implicit_type_params = implicit_type_params @@ -134,6 +136,8 @@ def convert_func_signature(signature: classinfo.MethodType, range } +NewVarHint = Literal[False, 'dsl', 'comptime'] + class FuncParser: name: str @@ -159,7 +163,7 @@ def __init__(self, name: str, self.signature = signature self.globalns = globalns obj_ast, _obj_file = retrieve_ast_and_filename(func) - # print(ast.dump(obj_ast)) + print(ast.dump(obj_ast)) assert isinstance(obj_ast, ast.Module), f"{obj_ast} is not a module" if not isinstance(obj_ast.body[0], ast.FunctionDef): raise RuntimeError("Function definition expected.") @@ -240,12 +244,12 @@ def convert_any_to_value(self, a: Any, span: hir.Span | None) -> hir.Value | Com raise hir.ParsingError( span, f"unsupported constant type {type(a.value)}, wrap it in lc.comptime(...) if you intead to use it as a compile-time expression") - def parse_name(self, name: ast.Name, maybe_new_var: bool) -> hir.Ref | hir.Value | ComptimeValue: + def parse_name(self, name: ast.Name, new_var_hint: NewVarHint) -> hir.Ref | hir.Value | ComptimeValue: span = hir.Span.from_ast(name) var = self.vars.get(name.id) if var is not None: return var - if maybe_new_var: + if new_var_hint == 'dsl': var = hir.Var(name.id, None, span) self.vars[name.id] = var return var @@ -254,7 +258,12 @@ def parse_name(self, name: ast.Name, maybe_new_var: bool) -> hir.Ref | hir.Value if name.id in self.globalns: resolved = self.globalns[name.id] return self.convert_any_to_value(resolved, span) - # assert isinstance(resolved, ComptimeValue), type(resolved) + elif new_var_hint == 'comptime': + self.globalns[name.id] = None + + def update_fn(value: Any) -> None: + self.globalns[name.id] = value + return ComptimeValue(None, update_fn) raise hir.ParsingError(name, f"unknown variable {name.id}") @@ -286,6 +295,7 @@ def parse_access_ref(self, expr: ast.Subscript | ast.Attribute) -> hir.Ref: if isinstance(expr, ast.Subscript): value = self.parse_ref(expr.value) index = self.parse_expr(expr.slice) + assert isinstance(value, hir.Ref) and isinstance(index, hir.Value) index = self.convert_to_value(index, span) assert value.type index_ty = self.get_index_type(span, value.type, index) @@ -305,6 +315,7 @@ def parse_access_ref(self, expr: ast.Subscript | ast.Attribute) -> hir.Ref: expr, f"indexing not supported for type {value.type}") elif isinstance(expr, ast.Attribute): value = self.parse_ref(expr.value) + assert isinstance(value, hir.Ref) attr_name = expr.attr assert value.type member_ty = value.type.member(attr_name) @@ -569,10 +580,10 @@ def infer_binop(name: str, rname: str) -> hir.Value: raise e from e return infer_binop(ops[0], ops[1]) - def parse_ref(self, expr: ast.expr, maybe_new_var: bool = False) -> hir.Ref: + def parse_ref(self, expr: ast.expr, new_var_hint: NewVarHint = False) -> hir.Ref | ComptimeValue: match expr: case ast.Name(): - ret = self.parse_name(expr, maybe_new_var) + ret = self.parse_name(expr, new_var_hint) if isinstance(ret, (hir.Value, ComptimeValue)): raise hir.ParsingError( expr, f"value cannot be used as reference") @@ -581,7 +592,103 @@ def parse_ref(self, expr: ast.expr, maybe_new_var: bool = False) -> hir.Ref: return self.parse_access_ref(expr) case _: raise hir.ParsingError( - expr, f"expression cannot be parsed as reference") + expr, f"expression {ast.dump(expr)} cannot be parsed as reference") + + # def parse_assignment_targets(self, targets: List[ast.expr], new_var_hint: NewVarHint) -> List[hir.Ref]: + # return [self.parse_ref(t, new_var_hint) for t in targets] + + # def assign(self, targets: List[hir.Ref], values: hir.Value | ComptimeValue) -> None: + # pass + + def parse_multi_assignment(self, + targets: List[ast.expr], + anno_ty_fn: List[Optional[Callable[..., hir.Type | None]]], + values: hir.Value | ComptimeValue) -> None: + if isinstance(values, ComptimeValue): + parsed_targets = [self.parse_ref(t, 'comptime') for t in targets] + + def do_assign(target: hir.Ref | ComptimeValue, value: ComptimeValue, i: int) -> None: + span = hir.Span.from_ast(targets[i]) + if isinstance(target, ComptimeValue): + target.update(value.value) + else: + self.cur_bb().append(hir.Assign(target, + self.try_convert_comptime_value(value, span))) + if len(parsed_targets) > 1: + if len(parsed_targets) != len(values.value): + raise hir.ParsingError( + targets[0], f"expected {len(parsed_targets)} values to unpack, got {len(values.value)}") + for i, t in enumerate(parsed_targets): + do_assign(t, values.value[i], + i) + else: + t = parsed_targets[0] + do_assign(t, values, 0) + else: + parsed_targets = [self.parse_ref(t, 'dsl') for t in targets] + is_all_dsl = all( + isinstance(t, hir.Ref) for t in parsed_targets) + if not is_all_dsl: + raise hir.ParsingError( + targets[0], "DSL value cannot be assigned to comptime variables") + assert values.type + ref_targets = cast(List[hir.Ref], parsed_targets) + + def do_unpack(length: int, extract_fn: Callable[[hir.Value, int, ast.expr], hir.Value]) -> None: + def check(i: int, val_type: hir.Type) -> None: + if len(anno_ty_fn) > 0 and (fn := anno_ty_fn[i]) is not None: + ty = fn() + if ty is None: + raise hir.ParsingError( + targets[i], f"unable to infer type of target") + if ref_targets[i].type is None: + ref_targets[i].type = ty + tt = ref_targets[i].type + if not tt: + if val_type.is_concrete(): + ref_targets[i].type = val_type + else: + raise hir.TypeInferenceError( + targets[i], f"unable to infer type of target, cannot assign with non-concrete type {val_type}") + elif not hir.is_type_compatible_to(val_type, tt): + raise hir.ParsingError( + targets[i], f"expected type {tt}, got {val_type}") + + if len(ref_targets) == 1: + assert values.type + check(0, values.type) + self.cur_bb().append(hir.Assign( + ref_targets[0], values)) + elif len(ref_targets) == length: + for i, t in enumerate(ref_targets): + e = extract_fn(values, i, targets[i]) + assert e.type + check(i, e.type) + self.cur_bb().append(hir.Assign(t, e)) + else: + if len(ref_targets) > length: + raise hir.ParsingError( + targets[0], f"too few values to unpack: expected {len(ref_targets)} values, got {length}") + else: + raise hir.ParsingError( + targets[0], f"too many values to unpack: expected {len(ref_targets)} values, got {length}") + match values.type: + case hir.VectorType() as vt: + comps = 'xyzw' + do_unpack(vt.count, lambda values, i, target: self.cur_bb().append( + hir.Member(values, comps[i], type=vt.element, span=hir.Span.from_ast(target)))) + case hir.TupleType() as tt: + do_unpack(len(tt.elements), lambda values, i, target: self.cur_bb().append( + hir.Member(values, f'_{i}', type=tt.elements[i], span=hir.Span.from_ast(target))) + ) + case hir.ArrayType() as at: + assert isinstance(at.count, int) + do_unpack(at.count, lambda values, i, target: self.cur_bb().append( + hir.Index(values, hir.Constant(i, type=hir.IntType(32, True)), type=at.element, span=hir.Span.from_ast(target)))) + case hir.StructType() as st: + do_unpack(len(st.fields), lambda values, i, target: self.cur_bb().append( + hir.Member(values, st.fields[i][0], type=st.fields[i][1], span=hir.Span.from_ast(target))) + ) def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue: match expr: @@ -598,6 +705,21 @@ def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue: return self.parse_binop(expr) case ast.Call(): return self.parse_call(expr) + case ast.Tuple(): + elements = [self.parse_expr(e) for e in expr.elts] + is_all_comptime = all( + isinstance(e, ComptimeValue) for e in elements) + if is_all_comptime: + return hir.ComptimeValue( + tuple(e.value for e in cast(List[ComptimeValue], elements)), None) + else: + for i, e in enumerate(elements): + if isinstance(e, ComptimeValue): + elements[i] = self.try_convert_comptime_value( + e, hir.Span.from_ast(expr.elts[i])) + tt: hir.TupleType = hir.TupleType( + [unwrap(e.type) for e in cast(List[hir.Value], elements)]) + return self.cur_bb().append(hir.AggregateInit(cast(List[hir.Value], elements), tt, span=hir.Span.from_ast(expr))) case _: raise RuntimeError(f"Unsupported expression: {ast.dump(expr)}") @@ -690,31 +812,38 @@ def check_return_type(ty: hir.Type) -> None: check_return_type(hir.UnitType()) self.cur_bb().append(hir.Return(None)) case ast.Assign(): - if len(stmt.targets) != 1: - raise hir.ParsingError(stmt, f"expected single target") + # if len(stmt.targets) != 1: + # raise hir.ParsingError(stmt, f"expected single target") + # target = stmt.targets[0] + # var = self.parse_ref(target, new_var_hint='dsl') + # value = self.parse_expr(stmt.value) + # if isinstance(var, ComptimeValue): + # if not isinstance(value, ComptimeValue): + # raise hir.ParsingError( + # stmt, f"comptime value cannot be assigned with DSL value") + # var.update(value.value) + # return None + # value = self.convert_to_value(value, span) + # assert value.type + # if var.type: + # if not hir.is_type_compatible_to(value.type, var.type): + # raise hir.ParsingError( + # stmt, f"expected {var.type}, got {value.type}") + # else: + # if not value.type.is_concrete(): + # raise hir.ParsingError( + # stmt, "only concrete type can be assigned, please annotate the variable with type hint") + # var.type = value.type + # self.cur_bb().append(hir.Assign(var, value, span)) + assert len(stmt.targets) == 1 target = stmt.targets[0] - var = self.parse_ref(target, maybe_new_var=True) - if isinstance(var, hir.Value): - raise hir.ParsingError(target, f"value cannot be assigned") - value = self.parse_expr(stmt.value) - if isinstance(var, ComptimeValue): - if not isinstance(value, ComptimeValue): - raise hir.ParsingError( - stmt, f"comptime value cannot be assigned with DSL value") - var.update(value.value) - return None - value = self.convert_to_value(value, span) - assert value.type - if var.type: - if not hir.is_type_compatible_to(value.type, var.type): - raise hir.ParsingError( - stmt, f"expected {var.type}, got {value.type}") + if isinstance(target, ast.Tuple): + self.parse_multi_assignment( + target.elts, [], self.parse_expr(stmt.value)) else: - if not value.type.is_concrete(): - raise hir.ParsingError( - stmt, "only concrete type can be assigned, please annotate the variable with type hint") - var.type = value.type - self.cur_bb().append(hir.Assign(var, value, span)) + self.parse_multi_assignment( + [target], [], self.parse_expr(stmt.value) + ) case ast.AugAssign(): method_name = AUG_ASSIGN_TO_METHOD_NAMES[type(stmt.op)] var = self.parse_ref(stmt.target) @@ -742,41 +871,60 @@ def check_return_type(ty: hir.Type) -> None: raise hir.ParsingError(stmt, ret.message) case ast.AnnAssign(): - var = self.parse_ref(stmt.target, maybe_new_var=True) - - type_annotation = self.eval_expr(stmt.annotation) - type_hint = classinfo.parse_type_hint(type_annotation) - ty = self.parse_type(type_hint) - assert ty - var.type = ty + def parse_anno_ty() -> hir.Type: + type_annotation = self.eval_expr(stmt.annotation) + type_hint = classinfo.parse_type_hint(type_annotation) + ty = self.parse_type(type_hint) + assert ty + return ty if stmt.value: - value = self.parse_expr(stmt.value) - - if isinstance(var, ComptimeValue): - if not isinstance(value, ComptimeValue): - raise hir.ParsingError( - stmt, f"comptime value cannot be assigned with DSL value") - var.update(value.value) - return None - if isinstance(value, ComptimeValue): - value = self.try_convert_comptime_value( - value, span) - elif isinstance(value, hir.Ref): - value = hir.Load(value) - assert value.type - assert ty - if not var.type.is_concrete(): - raise hir.ParsingError( - stmt, "only concrete type can be assigned, please annotate the variable with concrete types") - if not hir.is_type_compatible_to(value.type, ty): - raise hir.ParsingError( - stmt, f"expected {ty}, got {value.type}") - if not value.type.is_concrete(): - value.type = var.type - self.cur_bb().append(hir.Assign(var, value, span)) + self.parse_multi_assignment( + [stmt.target], [parse_anno_ty], self.parse_expr(stmt.value)) + # value = self.parse_expr(stmt.value) + # if isinstance(value, ComptimeValue): + # var = self.parse_ref( + # stmt.target, new_var_hint='comptime') + # else: + # var = self.parse_ref(stmt.target, new_var_hint='dsl') + # if isinstance(var, ComptimeValue): + # if isinstance(value, ComptimeValue): + # try: + # var.update(value.value) + # except Exception as e: + # raise hir.ParsingError( + # stmt, f"error updating comptime value: {e}") from e + # return + # else: + # raise hir.ParsingError( + # stmt, f"comptime value cannot be assigned with DSL value") + # else: + # if isinstance(value, ComptimeValue): + # value = self.try_convert_comptime_value( + # value, span) + # assert value.type + # anno_ty = parse_anno_ty() + # if not var.type: + # var.type = value.type + # if not var.type.is_concrete(): + # raise hir.ParsingError( + # stmt, "only concrete type can be assigned, please annotate the variable with concrete types") + # if not hir.is_type_compatible_to(value.type, anno_ty): + # raise hir.ParsingError( + # stmt, f"expected {anno_ty}, got {value.type}") + # if not value.type.is_concrete(): + # value.type = var.type + # self.cur_bb().append(hir.Assign(var, value, span)) else: + var = self.parse_ref(stmt.target, new_var_hint='dsl') + anno_ty = parse_anno_ty() assert isinstance(var, hir.Var) + if not var.type: + var.type = anno_ty + else: + if not hir.is_type_compatible_to(var.type, anno_ty): + raise hir.ParsingError( + stmt, f"expected {anno_ty}, got {var.type}") case ast.Expr(): self.parse_expr(stmt.value) case ast.Pass(): @@ -843,3 +991,5 @@ def parse_body(self): ast.RShift: ["__rshift__", "__rrshift__"], ast.Pow: ["__pow__", "__rpow__"], } + +__all__ = ["convert_func_signature", "FuncParser"]