Skip to content

Commit

Permalink
allowing type[T] to be passed to functions
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Dec 17, 2024
1 parent 4e956b3 commit 0e83814
Show file tree
Hide file tree
Showing 7 changed files with 321 additions and 185 deletions.
13 changes: 7 additions & 6 deletions luisa_lang/_builtin_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class _ObjKind(Enum):


def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optional[MethodType],
func_globals: Dict[str, Any], foreign_type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue],
func_globals: Dict[str, Any], foreign_type_var_ns: Dict[TypeVar, hir.Type],
props: hir.FuncProperties, self_type: Optional[hir.Type] = None):
# parsing_ctx = _parse.ParsingContext(func_name, func_globals)
# func_sig_parser = _parse.FuncParser(func_name, f, parsing_ctx, self_type)
Expand All @@ -88,8 +88,7 @@ def _make_func_template(f: Callable[..., Any], func_name: str, func_sig: Optiona
implicit_generic_params.add(p.param)

def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.Function:
type_var_ns: Dict[TypeVar, hir.Type |
hir.ComptimeValue] = foreign_type_var_ns.copy()
type_var_ns: Dict[TypeVar, hir.Type] = foreign_type_var_ns.copy()
mapped_implicit_type_params: Dict[str,
hir.Type] = dict()
assert func_sig is not None
Expand Down Expand Up @@ -166,7 +165,7 @@ def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:


_MakeTemplateFn = Callable[[List[hir.GenericParameter]], hir.Type]
_InstantiateFn = Callable[[List[Any]], hir.Type]
_InstantiateFn = Callable[[List[hir.Type]], hir.Type]


def _dsl_struct_impl(cls: type[_TT], attrs: Dict[str, Any], ir_ty_override: hir.Type | Tuple[_MakeTemplateFn, _InstantiateFn] | None = None, opqaue_override: str | None = None) -> type[_TT]:
Expand Down Expand Up @@ -202,6 +201,8 @@ def parse_fields(tp: parse.TypeParser, self_ty: hir.Type):

def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type,):
for name in cls_info.methods:
if name == '__setitem__': # __setitem__ is ignored deliberately
continue
method_object = getattr(cls, name)
props: hir.FuncProperties
if hasattr(method_object, '__luisa_func_props__'):
Expand All @@ -214,7 +215,7 @@ def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type,
method_object, get_full_name(method_object), cls_info.methods[name], globalns, type_var_ns, props, self_type=self_ty)
if isinstance(self_ty, hir.BoundType):
assert isinstance(self_ty.instantiated,
(hir.StructType, hir.OpaqueType))
(hir.ArrayType, hir.StructType, hir.OpaqueType))
self_ty.instantiated.methods[name] = template
else:
self_ty.methods[name] = template
Expand All @@ -235,7 +236,7 @@ def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type,
parse_fields(type_parser, ir_ty)
is_generic = len(cls_info.type_vars) > 0
if is_generic:
def monomorphization_func(args: List[hir.Type | Any]) -> hir.Type:
def monomorphization_func(args: List[hir.Type]) -> hir.Type:
assert isinstance(ir_ty, hir.ParametricType)
type_var_ns = {}
if len(args) != len(cls_info.type_vars):
Expand Down
16 changes: 15 additions & 1 deletion luisa_lang/classinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,21 @@ def __repr__(self):

def __eq__(self, other):
return isinstance(other, SelfType)

class LiteralType:
value: Any

def __init__(self, value: Any):
self.value = value

VarType = Union[TypeVar, Type, GenericInstance, UnionType, SelfType, AnyType]
def __repr__(self):
return f"Literal[{self.value}]"

def __eq__(self, other):
return isinstance(other, LiteralType) and self.value == other.value


VarType = Union[TypeVar, Type, GenericInstance, UnionType, SelfType, AnyType, LiteralType]


def subst_type(ty: VarType, env: Dict[TypeVar, VarType]) -> VarType:
Expand Down Expand Up @@ -204,6 +216,8 @@ def parse_type_hint(hint: Any) -> VarType:
return GenericInstance(origin, [parse_type_hint(arg) for arg in args])
elif origin is Union:
return UnionType([parse_type_hint(arg) for arg in typing.get_args(hint)])
elif origin is Literal:
return LiteralType(typing.get_args(hint)[0])
else:
raise RuntimeError(f"Unsupported origin type: {origin}")

Expand Down
87 changes: 53 additions & 34 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,12 @@ def gen_impl(self, ty: hir.Type) -> str:
raise RuntimeError("invalid float type")
case hir.BoolType():
return "lc_bool"
case hir.PointerType(element=element):
return f"lc_ptr<{self.gen(element)}>"
case hir.VectorType(element=element, count=count):
return f"{self.gen(element)}{count}>"
case hir.ArrayType(element=element, count=count):
return f"lc_array<{self.gen(element)}, {count}>"
case hir.StructType(name=name, fields=fields):
self.impl.writeln(f'struct {name} {{')
for field in fields:
Expand All @@ -80,9 +84,13 @@ def do():
assert ty.instantiated
return self.gen(ty.instantiated)
case hir.FunctionType():
return ''
name = f'func_{unique_hash(ty.func_like.name)}_t'
self.impl.writeln(f'struct {name} {{}}; // function type of {ty.func_like.name}')
return name
case hir.TypeConstructorType():
return ''
name = f'type_{unique_hash(self.gen(ty.inner))}_t'
self.impl.writeln(f'struct {name} {{}}; // type constructor of {ty.inner}')
return name
case hir.OpaqueType():
def do():
match ty.name:
Expand Down Expand Up @@ -171,8 +179,8 @@ def mangle_impl(self, obj: Union[hir.Type, hir.Function]) -> str:
case hir.Function(name=name, params=params, return_type=ret):
assert ret
name = mangle_name(name)
params = list(filter(lambda p: not isinstance(
p.type, (hir.FunctionType)), params))
# params = list(filter(lambda p: not isinstance(
# p.type, (hir.FunctionType)), params))
return f'{name}_' + unique_hash(f"F{name}_{self.mangle(ret)}{''.join(self.mangle(unwrap(p.type)) for p in params)}")
case hir.StructType(name=name):
return name
Expand All @@ -184,6 +192,10 @@ def mangle_impl(self, obj: Union[hir.Type, hir.Function]) -> str:
return self.mangle(obj.instantiated)
case hir.OpaqueType():
return obj.name
case hir.TypeConstructorType():
return self.mangle(obj.inner)
case hir.FunctionType():
return f'func_{unique_hash(obj.func_like.name)}'
case _:
raise NotImplementedError(f"unsupported object: {obj}")

Expand Down Expand Up @@ -246,7 +258,7 @@ def __init__(self, base: CppCodeGen, func: hir.Function) -> None:
self.name = base.mangling.mangle(func)
self.func = func
params = ",".join(self.gen_var(
p) for p in func.params if not isinstance(p.type, hir.FunctionType))
p) for p in func.params)
assert func.return_type

self.signature = f'auto {self.name}({params}) -> {base.type_cache.gen(func.return_type)}'
Expand Down Expand Up @@ -285,10 +297,13 @@ def do():
intrin_name = intrin.name
gened_args = [self.gen_value_or_ref(
arg) for arg in intrin.args]
if intrin_name == 'buffer_ref':
return f"{gened_args[0]}[{gened_args[1]}]"
else:
raise RuntimeError(f"unsupported intrinsic reference: {intrin_name}")
match intrin_name:
case 'buffer.ref' | 'array.ref':
return f"{gened_args[0]}[{gened_args[1]}]"
case 'buffer.size' | 'array.size':
return f"{gened_args[0]}.size"
case _:
raise RuntimeError(f"unsupported intrinsic reference: {intrin_name}")
return do()
case _:
raise NotImplementedError(f"unsupported reference: {ref}")
Expand All @@ -312,17 +327,40 @@ def gen_value_or_ref(self, value: hir.Value | hir.Ref) -> str:
def gen_node_checked(self, node: hir.Node) -> str:
if isinstance(node, hir.Constant):
return self.gen_expr(node)
if isinstance(node, hir.TypedNode) and isinstance(node.type, (hir.TypeConstructorType, hir.FunctionType)):
assert node.type
return f'{self.base.type_cache.gen(node.type)}{{}}'

return self.node_map[node]

def gen_expr(self, expr: hir.Value) -> str:
if expr.type and isinstance(expr.type, hir.FunctionType):
return ''
# if expr.type and isinstance(expr.type, hir.FunctionType):
# return ''
if isinstance(expr, hir.Constant):
value = expr.value
if isinstance(value, int):
return f"{value}"
elif isinstance(value, float):
return f"{value}f"
elif isinstance(value, bool):
return "true" if value else "false"
elif isinstance(value, str):
return f"\"{value}\""
elif isinstance(value, hir.Function):
return self.gen_func(value)
else:
raise NotImplementedError(
f"unsupported constant: {expr}")
if expr in self.node_map:
return self.node_map[expr]
vid = self.new_vid()

def impl() -> None:
match expr:
case hir.TypeValue() as type_value:
assert type_value.type
self.base.type_cache.gen(type_value.type)
return
case hir.Load() as load:
self.body.writeln(
f"const auto &v{vid} = {self.gen_ref(load.ref)}; // load")
Expand All @@ -337,36 +375,17 @@ def impl() -> None:
case hir.Call() as call:
op = self.gen_func(call.op)
args_s = ','.join(self.gen_value_or_ref(
arg) for arg in call.args if not isinstance(arg.type, hir.FunctionType))
arg) for arg in call.args)
if call.type != hir.UnitType():
self.body.writeln(
f"auto v{vid} = {op}({args_s});")
else:
self.body.writeln(f"{op}({args_s});")
case hir.Constant() as constant:
value = constant.value
if isinstance(value, int):
self.body.writeln(f"const auto v{vid} = {value};")
elif isinstance(value, float):
self.body.writeln(f"const auto v{vid} = {value};")
elif isinstance(value, bool):
s = "true" if value else "false"
self.body.writeln(f"const auto v{vid} = {s};")
elif isinstance(value, str):
self.body.writeln(f"const auto v{vid} = \"{value}\";")
elif isinstance(value, hir.Function):
name = self.gen_func(value)
self.body.writeln(f"auto&& v{vid} = {name};")
else:
raise NotImplementedError(
f"unsupported constant: {constant}")
case hir.AggregateInit():
assert expr.type
ty = self.base.type_cache.gen(expr.type)
self.body.writeln(
f"{ty} v{vid}{{ {','.join(self.gen_expr(e) for e in expr.args)} }};")
case hir.TypeValue():
pass
case hir.Intrinsic() as intrin:
def do():
intrin_name = intrin.name
Expand Down Expand Up @@ -544,7 +563,7 @@ def gen_node(self, node: hir.Node) -> Optional[hir.BasicBlock]:
ty = self.base.type_cache.gen(alloca.type)
self.body.writeln(f"{ty} v{vid}{{}};")
self.node_map[alloca] = f"v{vid}"
case hir.AggregateInit() | hir.Intrinsic() | hir.Call() | hir.Constant() | hir.Load() | hir.Index() | hir.Member():
case hir.AggregateInit() | hir.Intrinsic() | hir.Call() | hir.Constant() | hir.Load() | hir.Index() | hir.Member() | hir.TypeValue() | hir.FunctionValue():
self.gen_expr(node)
case hir.Member() | hir.Index():
pass
Expand All @@ -570,8 +589,8 @@ def gen_locals(self):
for local in self.func.locals:
if local.name in self.params:
continue
if isinstance(local.type, (hir.FunctionType, hir.TypeConstructorType)):
continue
# if isinstance(local.type, (hir.FunctionType, hir.TypeConstructorType)):
# continue
assert (
local.type
), f"Local variable `{local.name}` contains unresolved type"
Expand Down
Loading

0 comments on commit 0e83814

Please sign in to comment.