From 059f91e9669be370fc104bcd2d87217407e29962 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Thu, 24 Oct 2024 00:48:59 -0400 Subject: [PATCH] thinking about refactor parsiing & type inferencing for better metaprogramming support --- luisa_lang/codegen/cpp.py | 2 ++ luisa_lang/hir/defs.py | 7 ++++-- luisa_lang/hir/infer.py | 8 +++++-- luisa_lang/lang.py | 3 ++- luisa_lang/parse.py | 48 +++++++++++++++++++++++++++++++++------ setup.py | 3 ++- 6 files changed, 58 insertions(+), 13 deletions(-) diff --git a/luisa_lang/codegen/cpp.py b/luisa_lang/codegen/cpp.py index 532f28a..33fac90 100644 --- a/luisa_lang/codegen/cpp.py +++ b/luisa_lang/codegen/cpp.py @@ -119,6 +119,7 @@ def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str: case hir.VectorType(element=element, count=count): return f"V{count}{self.mangle(element)}" case hir.Function(name=name, params=params, return_type=ret): + assert ret name = mangle_name(name) return f'{name}_' + unique_hash(f"F{name}_{self.mangle(ret)}{''.join(self.mangle(unwrap(p.type)) for p in params)}") case hir.BuiltinFunction(name=name): @@ -186,6 +187,7 @@ def __init__(self, base: CppCodeGen, func: hir.Function) -> None: self.name = base.mangling.mangle(func) self.func = func params = ",".join(self.gen_var(p) for p in func.params) + assert func.return_type self.signature = f'extern "C" auto {self.name}({params}) -> {base.type_cache.gen(func.return_type)}' self.body = ScratchBuffer() self.params = set(p.name for p in func.params) diff --git a/luisa_lang/hir/defs.py b/luisa_lang/hir/defs.py index 8a76f04..30440a8 100644 --- a/luisa_lang/hir/defs.py +++ b/luisa_lang/hir/defs.py @@ -135,6 +135,9 @@ def __eq__(self, value: object) -> bool: def __hash__(self) -> int: return hash(UnitType) + + def __str__(self) -> str: + return "NoneType" class ScalarType(Type): @@ -1066,7 +1069,7 @@ class Function: name: str generic_params: Dict[str, GenericParameter] params: List[Var] - return_type: Type + return_type: Type | None body: List[Stmt] builtin: bool export: bool @@ -1079,7 +1082,7 @@ def __init__( name: str, generic_params: Dict[str, GenericParameter], params: List[Var], - return_type: Type, + return_type: Type | None, body: List[Stmt], locals: List[Var], builtin: bool = False, diff --git a/luisa_lang/hir/infer.py b/luisa_lang/hir/infer.py index ca28591..8c55f46 100644 --- a/luisa_lang/hir/infer.py +++ b/luisa_lang/hir/infer.py @@ -22,6 +22,8 @@ def wrapper(inferencer: 'FuncTypeInferencer', node: hir.TypedNode, *args) -> Opt def is_function_fully_typed(func: hir.Function) -> bool: + if not func.return_type: + return False for stmt in func.body: if not is_stmt_fully_typed(stmt): return False @@ -112,7 +114,9 @@ def infer_stmt(self, stmt: hir.Stmt) -> None: case hir.Return(value=value): if value: ty = self.infer_expr(value) - if self.func.return_type != ty: + if not self.func.return_type: + self.func.return_type = ty + elif self.func.return_type != ty: report_error( stmt.span, f"Return type mismatch: expected {self.func.return_type}, got {ty}", @@ -247,7 +251,7 @@ def _infer_call_helper( # traceback.print_exc() raise hir.TypeInferenceError( node, - f"Error during instantiating function template {f.name}: {e}") + f"Error during instantiating function template {f.name}: {e}") from e else: resolved_f = f.resolve(None) node.op = hir.Constant(resolved_f) diff --git a/luisa_lang/lang.py b/luisa_lang/lang.py index 0797f04..48e87b3 100644 --- a/luisa_lang/lang.py +++ b/luisa_lang/lang.py @@ -50,6 +50,7 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike: if is_generic: mapping = hir.match_func_template_args(func_sig, args) if len(mapping) != len(func_sig.generic_params): + print(mapping, func_sig.generic_params) raise hir.TypeInferenceError( None, "not all type parameters are resolved") for p in func_sig.generic_params.values(): @@ -57,7 +58,7 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike: raise hir.TypeInferenceError( None, f"type parameter {p} is not resolved") parsing_ctx.bound_type_vars[p.name] = mapping[p] - # print(f'binding {p.name} = {mapping[p]}') + print(f'binding {p.name} = {mapping[p]}') func_parser = parse.FuncParser(func_name, f, parsing_ctx, self_type) func_ir = func_parser.parse_body() hir.run_inference_on_function(func_ir) diff --git a/luisa_lang/parse.py b/luisa_lang/parse.py index 067a717..fa5cf67 100644 --- a/luisa_lang/parse.py +++ b/luisa_lang/parse.py @@ -154,6 +154,13 @@ def get_access_key() -> Optional[Tuple[AccessKind, List[AccessKey]]]: else: raise RuntimeError( "Associated type not supported by Python") + elif cur is typing.Any: + # generic + if access is None: + return cur, None + else: + raise RuntimeError( + "Associated type not supported by Python") if chain_idx is None: if len(self.chain) == 0: break @@ -175,6 +182,7 @@ class ParsingContext: bound_type_vars: Dict[str, Union[hir.Type, hir.Value]] type_vars: Dict[typing.TypeVar, Tuple[hir.GenericParameter, Union[hir.Type, hir.Value]]] + any_cnt: int def __init__(self, ctx_name: str, globals: Dict[str, Any]): self.globals = globals @@ -183,7 +191,8 @@ def __init__(self, ctx_name: str, globals: Dict[str, Any]): self.ctx_name = ctx_name self.type_vars = {} self.bound_type_vars = {} - + self.any_cnt = 0 + def __eval_name(self, name: str) -> Optional[Any]: try: if name in self.name_eval_cache: @@ -264,17 +273,38 @@ def check_is_access(tree: ast.AST) -> bool: return None # report_error(tree, f"unsupported access chain {tree}") - def parse_type(self, type_tree: ast.AST, allow_new_typevar: bool = False) -> Optional[Type]: + def parse_type(self, type_tree: ast.AST, is_sig_params: bool = False) -> Optional[Type]: + """ + `is_sig_params` should be set if is parsing function arguments + """ acess_chain: AccessChain | None = self._parse_access_chain( type_tree, True) if acess_chain is None: return None # print(acess_chain) resolved, remaining = acess_chain.resolve() - if remaining is not None: - report_error(type_tree, f"failed to resolve type. {resolved},{remaining}") + if remaining is not None and remaining != []: + report_error( + type_tree, f"failed to resolve type. {resolved},{remaining}") if isinstance(resolved, hir.Type): return resolved + if resolved == typing.Any: + if is_sig_params: + # create a new generic parameter + param = hir.GenericParameter( + f"Any#{self.any_cnt}", self.ctx_name, None) + type_var = typing.TypeVar( # type: ignore + f"Any#{self.any_cnt}", bound=Any) # type: ignore + self.any_cnt += 1 + if param.name in self.bound_type_vars: + any_ty = self.bound_type_vars[param.name] + assert isinstance(any_ty, hir.Type) + return any_ty + generic_ty = hir.SymbolicType(param) + self.type_vars[type_var] = (param, generic_ty) + return generic_ty + else: + return None if isinstance(resolved, typing.TypeVar): # if resolved.__name__ in self.bound_type_vars: # ty_or_val = self.bound_type_vars[resolved.__name__] @@ -289,7 +319,7 @@ def parse_type(self, type_tree: ast.AST, allow_new_typevar: bool = False) -> Opt else: report_error( type_tree, f"expected generic parameter {resolved} to be a type but got a value: {ty_or_val}") - elif allow_new_typevar: + elif is_sig_params: ty_bound: hir.TypeBound | None = None # create new type var constraints, bound = get_typevar_constrains_and_bounds( @@ -316,6 +346,7 @@ def parse_type(self, type_tree: ast.AST, allow_new_typevar: bool = False) -> Opt type_tree, f"undefined type parameter {resolved}. type parameter must be included in the function signature or class definition") return None + class FuncParser: p_ctx: ParsingContext vars: Dict[str, hir.Var] @@ -330,6 +361,7 @@ class FuncParser: def __init__(self, name: str, func: object, p_ctx: ParsingContext, self_type: Optional[Type] = None) -> None: obj_ast, _obj_file = retrieve_ast_and_filename(func) + print(ast.dump(obj_ast)) assert isinstance(obj_ast, ast.Module), f"{obj_ast} is not a module" if not isinstance(obj_ast.body[0], ast.FunctionDef): raise RuntimeError("Function definition expected.") @@ -346,7 +378,7 @@ def __init__(self, name: str, func: object, p_ctx: ParsingContext, self_type: Op self.signature_initialized = True # print(self.arg_types, "->", self.return_type) - assert self.return_type is not None + # assert self.return_type is not None generic_params: Dict[str, hir.GenericParameter] = {} for tv in self.p_ctx.type_vars: param, _ = self.p_ctx.type_vars[tv] @@ -401,9 +433,11 @@ def _init_signature( self.return_type = self.self_type else: if func.returns is None: + self.return_type = None + elif isinstance(func.returns,ast.Constant) and func.returns.value is None: self.return_type = hir.UnitType() else: - self.return_type = p_ctx.parse_type(func.returns, True) + self.return_type = p_ctx.parse_type(func.returns, False) def parse_const(self, const: ast.Constant) -> hir.Value: span = hir.Span.from_ast(const) diff --git a/setup.py b/setup.py index 5b142ed..ae28718 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,8 @@ from setuptools import setup, find_packages setup( - name="luisa_lang", + name="luisa-python-lang", + description="A New DSL Frontend for LuisaCompute", version="0.1", packages=find_packages(), package_data={"luisa_lang": ["py.typed"]},