Skip to content

Commit

Permalink
refactoring parser
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 28, 2024
1 parent 59100ff commit 915507c
Show file tree
Hide file tree
Showing 11 changed files with 532 additions and 1,583 deletions.
19 changes: 11 additions & 8 deletions luisa_lang/classinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Literal,
Optional,
Set,
Tuple,
TypeVar,
Generic,
Dict,
Expand Down Expand Up @@ -39,13 +40,15 @@ def __init__(self, types: List["VarType"]):
def __repr__(self):
return f"Union[{', '.join(map(repr, self.types))}]"


class AnyType:
def __repr__(self):
return "Any"

def __eq__(self, other):
return isinstance(other, AnyType)


class SelfType:
def __repr__(self):
return "Self"
Expand All @@ -69,13 +72,13 @@ def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:

class MethodType:
type_vars: List[TypeVar]
args: List[VarType]
args: List[Tuple[str, VarType]]
return_type: VarType
env: Dict[TypeVar, VarType]
is_static: bool

def __init__(
self, type_vars: List[TypeVar], args: List[VarType], return_type: VarType, env: Optional[Dict[TypeVar, VarType]] = None, is_static: bool = False
self, type_vars: List[TypeVar], args: List[Tuple[str, VarType]], return_type: VarType, env: Optional[Dict[TypeVar, VarType]] = None, is_static: bool = False
):
self.type_vars = type_vars
self.args = args
Expand All @@ -88,7 +91,7 @@ def __repr__(self):
return f"[{', '.join(map(repr, self.type_vars))}]({', '.join(map(repr, self.args))}) -> {self.return_type}"

def substitute(self, env: Dict[TypeVar, VarType]) -> "MethodType":
return MethodType([], [subst_type(arg, env) for arg in self.args], subst_type(self.return_type, env), env)
return MethodType([], [(arg[0], subst_type(arg[1], env)) for arg in self.args], subst_type(self.return_type, env), env)


class ClassType:
Expand Down Expand Up @@ -229,15 +232,15 @@ def parse_func_signature(func: object, globalns: Dict[str, Any], foreign_type_va
assert inspect.isfunction(func)
signature = inspect.signature(func)
method_type_hints = typing.get_type_hints(func, globalns)
param_types: List[VarType] = []
param_types: List[Tuple[str, VarType]] = []
type_vars = get_type_vars(func)
for param in signature.parameters.values():
if param.name == "self":
assert self_type is not None
param_types.append(self_type)
param_types.append((param.name, self_type))
else:
param_types.append(parse_type_hint(
method_type_hints[param.name]))
param_types.append((param.name, parse_type_hint(
method_type_hints[param.name])))
if "return" in method_type_hints:
return_type = parse_type_hint(method_type_hints.get("return"))
else:
Expand Down
38 changes: 24 additions & 14 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from luisa_lang.codegen import CodeGen, ScratchBuffer
from typing import Any, Callable, Dict, Set, Tuple, Union

from luisa_lang.hir.defs import GlobalContext
from luisa_lang.hir import get_dsl_func
from luisa_lang.hir.infer import run_inference_on_function


class TypeCodeGenCache:
Expand Down Expand Up @@ -150,11 +148,11 @@ def gen_function(self, func: hir.Function | Callable[..., Any]) -> str:
assert dsl_func is not None
assert not dsl_func.is_generic, f"Generic functions should be resolved before codegen: {func}"
func_tmp = dsl_func.resolve([])
assert isinstance(func_tmp, hir.Function), f"Expected function, got {func_tmp}"
assert isinstance(
func_tmp, hir.Function), f"Expected function, got {func_tmp}"
func = func_tmp
if id(func) in self.func_cache:
return self.func_cache[id(func)][1]
run_inference_on_function(func)
func_code_gen = FuncCodeGen(self, func)
name = func_code_gen.name
self.func_cache[id(func)] = (func, name)
Expand Down Expand Up @@ -207,17 +205,31 @@ def gen_ref(self, ref: hir.Ref) -> str:
case _:
raise NotImplementedError(f"unsupported reference: {ref}")

def gen_func(self, f: hir.FunctionLike) -> str:
if isinstance(f, hir.Function):
return self.base.gen_function(f)
elif isinstance(f, hir.BuiltinFunction):
return self.base.mangling.mangle(f)
else:
raise NotImplementedError(f"unsupported constant")

def gen_value_or_ref(self, value: hir.Value | hir.Ref) -> str:
match value:
case hir.Value() as value:
return self.gen_expr(value)
case hir.Ref() as ref:
return self.gen_ref(ref)
case _:
raise NotImplementedError(
f"unsupported value or reference: {value}")

def gen_expr(self, expr: hir.Value) -> str:
match expr:
case hir.Load() as load:
return self.gen_ref(load.ref)
case hir.Call() as call:
assert call.resolved, f"unresolved call: {call}"
kind = call.kind
assert kind == hir.CallOpKind.FUNC and isinstance(
call.op, hir.Value)
op = self.gen_expr(call.op)
return f"{op}({','.join(self.gen_expr(arg) for arg in call.args)})"
op = self.gen_func(call.op)
return f"{op}({','.join(self.gen_value_or_ref(arg) for arg in call.args)})"
case hir.Constant() as constant:
value = constant.value
if isinstance(value, int):
Expand All @@ -228,10 +240,8 @@ def gen_expr(self, expr: hir.Value) -> str:
return "true" if value else "false"
elif isinstance(value, str):
return f'"{value}"'
elif isinstance(value, hir.Function):
return self.base.gen_function(value)
elif isinstance(value, hir.BuiltinFunction):
return self.base.mangling.mangle(value)
elif isinstance(value, hir.Function) or isinstance(value, hir.BuiltinFunction):
return self.gen_func(value)
else:
raise NotImplementedError(
f"unsupported constant: {constant}")
Expand Down
Loading

0 comments on commit 915507c

Please sign in to comment.