Skip to content

Commit

Permalink
generic structs
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Nov 9, 2024
1 parent 699aff2 commit e3fec07
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 114 deletions.
103 changes: 75 additions & 28 deletions luisa_lang/_builtin_decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def make_builtin():


def builtin(s: str) -> Callable[[_F], _F]:
def wrapper(func: _F) -> _F:
def wrapper(func: _F) -> _F:
setattr(func, "__luisa_builtin__", s)
return func
return wrapper
Expand All @@ -148,15 +148,14 @@ def _intrinsic_impl(*args, **kwargs) -> Any:
)



class _ObjKind(Enum):
BUILTIN_TYPE = auto()
STRUCT = auto()
FUNC = auto()
KERNEL = auto()


def _make_func_template(f: Callable[..., Any], func_name: str, func_globals: Dict[str, Any], self_type: Optional[hir.Type] = None):
def _make_func_template(f: Callable[..., Any], func_name: str, func_globals: Dict[str, Any], foreign_type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue], self_type: Optional[hir.Type] = None):
# parsing_ctx = _parse.ParsingContext(func_name, func_globals)
# func_sig_parser = _parse.FuncParser(func_name, f, parsing_ctx, self_type)
# func_sig = func_sig_parser.parsed_func
Expand All @@ -165,15 +164,16 @@ def _make_func_template(f: Callable[..., Any], func_name: str, func_globals: Dic

func_sig = classinfo.parse_func_signature(f, func_globals, [])
func_sig_converted, sig_parser = parse.convert_func_signature(
func_sig, func_name, func_globals, {}, {}, self_type)
func_sig, func_name, func_globals, foreign_type_var_ns, {}, self_type)
implicit_type_params = sig_parser.implicit_type_params
implicit_generic_params: Set[hir.GenericParameter] = set()
for p in implicit_type_params.values():
assert isinstance(p, hir.SymbolicType)
implicit_generic_params.add(p.param)

def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
type_var_ns: Dict[TypeVar, hir.Type | hir.ComptimeValue] = {}
type_var_ns: Dict[TypeVar, hir.Type |
hir.ComptimeValue] = foreign_type_var_ns.copy()
mapped_implicit_type_params: Dict[str,
hir.Type] = dict()
if is_generic:
Expand Down Expand Up @@ -206,10 +206,14 @@ def parsing_func(args: hir.FunctionTemplateResolvingArgs) -> hir.FunctionLike:
return func_parser.parse_body()
params = [v[0] for v in func_sig.args]
is_generic = len(func_sig_converted.generic_params) > 0
# print(f"func {func_name} is_generic: {is_generic}")
# print(
# f"func {func_name} is_generic: {is_generic} {func_sig_converted.generic_params}")
return hir.FunctionTemplate(func_name, params, parsing_func, is_generic)


_TT = TypeVar('_TT')


def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
import sourceinspect
assert inspect.isfunction(f), f"{f} is not a function"
Expand All @@ -220,7 +224,7 @@ def _dsl_func_impl(f: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
func_globals: Any = getattr(f, "__globals__", {})

if kind == _ObjKind.FUNC:
template = _make_func_template(f, func_name, func_globals)
template = _make_func_template(f, func_name, func_globals, {})
ctx.functions[f] = template
setattr(f, "__luisa_func__", template)
return typing.cast(_TT, f)
Expand All @@ -236,32 +240,73 @@ def _dsl_struct_impl(cls: type[_TT], attrs: Dict[str, Any]) -> type[_TT]:
cls_info = class_typeinfo(cls)
globalns = _get_cls_globalns(cls)
globalns[cls.__name__] = cls

def get_ir_type(var_ty: VarType) -> hir.Type:
if isinstance(var_ty, (UnionType, classinfo.AnyType, classinfo.SelfType)):
raise RuntimeError("Struct fields cannot be UnionType")
if isinstance(var_ty, TypeVar):
raise NotImplementedError()
if isinstance(var_ty, GenericInstance):
type_var_to_generic_param: Dict[TypeVar, hir.GenericParameter] = {}
for type_var in cls_info.type_vars:
type_var_to_generic_param[type_var] = hir.GenericParameter(
type_var.__name__, cls.__qualname__)

def parse_fields(tp: parse.TypeParser, self_ty: hir.Type):
fields: List[Tuple[str, hir.Type]] = []
for name, field in cls_info.fields.items():
field_ty = tp.parse_type(field)
if field_ty is None:
raise hir.TypeInferenceError(
None, f"Cannot infer type for field {name} of {cls.__name__}")
fields.append((name, field_ty))
if isinstance(self_ty, hir.StructType):
self_ty.fields = fields
elif isinstance(self_ty, hir.BoundType):
assert isinstance(self_ty.instantiated, hir.StructType)
self_ty.instantiated.fields = fields
else:
raise NotImplementedError()
return ctx.types[var_ty]

fields: List[Tuple[str, hir.Type]] = []
for name, field in cls_info.fields.items():
fields.append((name, get_ir_type(field)))
ir_ty = hir.StructType(
f'{cls.__name__}_{unique_hash(cls.__qualname__)}', cls.__qualname__, fields)
def parse_methods(type_var_ns: Dict[TypeVar, hir.Type | Any], self_ty: hir.Type):
for name in cls_info.methods:
method_object = getattr(cls, name)
template = _make_func_template(
method_object, get_full_name(method_object), globalns, type_var_ns, self_type=self_ty)
if isinstance(self_ty, hir.BoundType):
assert isinstance(self_ty.instantiated, hir.StructType)
self_ty.instantiated.methods[name] = template
else:
self_ty.methods[name] = template

ir_ty: hir.Type = hir.StructType(
f'{cls.__name__}_{unique_hash(cls.__qualname__)}', cls.__qualname__, [])
type_parser = parse.TypeParser(
cls.__qualname__, globalns, {}, ir_ty, 'parse')

parse_fields(type_parser, ir_ty)
is_generic = len(cls_info.type_vars) > 0
if is_generic:
def monomorphization_func(args: List[hir.Type | Any]) -> hir.Type:
assert isinstance(ir_ty, hir.ParametricType)
type_var_ns = {}
if len(args) != len(cls_info.type_vars):
raise hir.TypeInferenceError(
None, f"Expected {len(cls_info.type_vars)} type arguments but got {len(args)}")
for i, arg in enumerate(args):
type_var_ns[cls_info.type_vars[i]] = arg
hash_s = unique_hash(f'{cls.__qualname__}_{args}')
inner_ty = hir.StructType(
f'{cls.__name__}_{hash_s}M', f'{cls.__qualname__}[{",".join([str(a) for a in args])}]', [])
mono_self_ty = hir.BoundType(ir_ty, args, inner_ty)
mono_type_parser = parse.TypeParser(
cls.__qualname__, globalns, type_var_ns, mono_self_ty, 'instantiate')
parse_fields(mono_type_parser, mono_self_ty)
parse_methods(type_var_ns, mono_self_ty)
return inner_ty
ir_ty = hir.ParametricType(
list(type_var_to_generic_param.values()), ir_ty, monomorphization_func)
else:
pass
ctx.types[cls] = ir_ty

for name, method in cls_info.methods.items():
method_object = getattr(cls, name)
template = _make_func_template(
method_object, get_full_name(method_object), globalns, self_type=ir_ty)
ir_ty.methods[name] = template
if not is_generic:
parse_methods({},ir_ty)
return cls



def _dsl_decorator_impl(obj: _TT, kind: _ObjKind, attrs: Dict[str, Any]) -> _TT:
if kind == _ObjKind.STRUCT:
assert isinstance(obj, type), f"{obj} is not a type"
Expand All @@ -288,8 +333,10 @@ def volume(self) -> float:
"""
return _dsl_decorator_impl(cls, _ObjKind.STRUCT, {})


_KernelType = TypeVar("_KernelType", bound=Callable[..., None])


@overload
def kernel(f: _KernelType) -> _KernelType: ...

Expand Down Expand Up @@ -353,4 +400,4 @@ def impl(f: _F) -> _F:
def decorator(f):
return impl(f)

return decorator
return decorator
40 changes: 32 additions & 8 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from luisa_lang import hir
from luisa_lang.utils import unique_hash, unwrap
from luisa_lang.codegen import CodeGen, ScratchBuffer
from typing import Any, Callable, Dict, Set, Tuple, Union
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

from luisa_lang.hir import get_dsl_func

Expand Down Expand Up @@ -54,6 +54,13 @@ def do():
self.impl.writeln('};')
return name
return do()
case hir.BoundType():
assert ty.instantiated
return self.gen(ty.instantiated)
case hir.FunctionType():
return ''
case hir.TypeConstructorType():
return ''
case _:
raise NotImplementedError(f"unsupported type: {ty}")

Expand Down Expand Up @@ -144,6 +151,9 @@ def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
case hir.TupleType():
elements = [self.mangle(e) for e in obj.elements]
return f"T{unique_hash(''.join(elements))}"
case hir.BoundType():
assert obj.instantiated
return self.mangle(obj.instantiated)
case _:
raise NotImplementedError(f"unsupported object: {obj}")

Expand Down Expand Up @@ -300,6 +310,8 @@ def impl() -> None:
ty = self.base.type_cache.gen(expr.type)
self.body.writeln(
f"{ty} v{vid}{{ {','.join(self.gen_expr(e) for e in expr.args)} }};")
case hir.TypeValue():
pass
case _:
raise NotImplementedError(
f"unsupported expression: {expr}")
Expand All @@ -308,7 +320,7 @@ def impl() -> None:
self.node_map[expr] = f'v{vid}'
return f'v{vid}'

def gen_node(self, node: hir.Node):
def gen_node(self, node: hir.Node) -> Optional[hir.BasicBlock]:

match node:
case hir.Return() as ret:
Expand All @@ -331,7 +343,7 @@ def gen_node(self, node: hir.Node):
self.body.indent += 1
self.gen_bb(if_stmt.else_body)
self.body.indent -= 1
self.gen_bb(if_stmt.merge)
return if_stmt.merge
case hir.Break():
self.body.writeln("__loop_break = true; break;")
case hir.Continue():
Expand All @@ -349,7 +361,7 @@ def gen_node(self, node: hir.Node):
if (loop_break) break;
update();
}
"""
self.body.writeln("while(true) {")
self.body.indent += 1
Expand All @@ -368,6 +380,7 @@ def gen_node(self, node: hir.Node):
self.gen_bb(loop.update)
self.body.indent -= 1
self.body.writeln("}")
return loop.merge
case hir.Alloca() as alloca:
vid = self.new_vid()
assert alloca.type
Expand All @@ -378,12 +391,23 @@ def gen_node(self, node: hir.Node):
self.gen_expr(node)
case hir.Member() | hir.Index():
pass
return None

def gen_bb(self, bb: hir.BasicBlock):
self.body.writeln(f"{{ // BasicBlock Begin {bb.span}")
for node in bb.nodes:
self.gen_node(node)
self.body.writeln(f"}} // BasicBlock End {bb.span}")

while True:
loop_again = False
old_bb = bb
self.body.writeln(f"// BasicBlock Begin {bb.span}")
for i, node in enumerate(bb.nodes):
if (next := self.gen_node(node)) and next is not None:
assert i == len(bb.nodes) - 1
loop_again = True
bb = next
break
self.body.writeln(f"// BasicBlock End {old_bb.span}")
if not loop_again:
break

def gen_locals(self):
for local in self.func.locals:
Expand Down
Loading

0 comments on commit e3fec07

Please sign in to comment.