Skip to content

Commit

Permalink
while loop and multi assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Nov 5, 2024
1 parent da68f83 commit b51cfb6
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 74 deletions.
23 changes: 20 additions & 3 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ def gen_impl(self, ty: hir.Type) -> str:
return name
case hir.UnitType():
return 'void'
case hir.TupleType():
def do():
elements = [self.gen(e) for e in ty.elements]
name = f'Tuple_{unique_hash("".join(elements))}'
self.impl.writeln(f'struct {name} {{')
for i, element in enumerate(elements):
self.impl.writeln(f' {element} _{i};')
self.impl.writeln('};')
return name
return do()
case _:
raise NotImplementedError(f"unsupported type: {ty}")

Expand Down Expand Up @@ -129,6 +139,9 @@ def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
return f"__builtin_{name}"
case hir.StructType(name=name):
return name
case hir.TupleType():
elements = [self.mangle(e) for e in obj.elements]
return f"T{unique_hash(''.join(elements))}"
case _:
raise NotImplementedError(f"unsupported object: {obj}")

Expand Down Expand Up @@ -275,6 +288,10 @@ def impl() -> None:
else:
raise NotImplementedError(
f"unsupported constant: {constant}")
case hir.AggregateInit():
assert expr.type
ty = self.base.type_cache.gen(expr.type)
self.body.writeln(f"{ty} v{vid}{{}};")
case _:
raise NotImplementedError(
f"unsupported expression: {expr}")
Expand Down Expand Up @@ -310,12 +327,12 @@ def gen_node(self, node: hir.Node):
vid = self.new_vid()
self.body.write(f"auto loop{vid}_prepare = [&]()->bool {{")
self.body.indent += 1
self.gen_bb(loop.prepare)
self.gen_bb(loop.prepare)
if loop.cond:
self.body.writeln(f"return {self.gen_expr(loop.cond)};")
else:
self.body.writeln("return true;")
self.body.indent -=1
self.body.indent -= 1
self.body.writeln("};")
self.body.writeln(f"auto loop{vid}_body = [&]() {{")
self.body.indent += 1
Expand Down Expand Up @@ -354,7 +371,7 @@ def gen_locals(self):
continue
assert (
local.type
), f"Local variable {local.name} contains unresolved type, please resolve it via TypeInferencer"
), f"Local variable `{local.name}` contains unresolved type"
self.body.writeln(
f"{self.base.type_cache.gen(local.type)} {local.name}{{}};"
)
Expand Down
16 changes: 13 additions & 3 deletions luisa_lang/hir.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Tuple,
Dict,
Union,
cast,
)
import typing
from typing_extensions import override
Expand Down Expand Up @@ -128,6 +127,9 @@ def method(self, name: str) -> Optional[FunctionLike | FunctionTemplate]:

def is_concrete(self) -> bool:
return True

def __len__(self) -> int:
return 1


class UnitType(Type):
Expand Down Expand Up @@ -337,7 +339,9 @@ def member(self, field: Any) -> Optional['Type']:
return self.element
return Type.member(self, field)


def __len__(self) -> int:
return self.count

class ArrayType(Type):
element: Type
count: Union[int, "SymbolicConstant"]
Expand Down Expand Up @@ -789,7 +793,7 @@ class Index(Value):
index: Value

def __init__(self, base: Value, index: Value, type: Type, span: Optional[Span]) -> None:
super().__init__(None, span)
super().__init__(type, span)
self.base = base
self.index = index

Expand Down Expand Up @@ -857,6 +861,12 @@ def __init__(self, ty: Type, span: Optional[Span] = None) -> None:
# super().__init__(ty, span)
# self.init_call = init_call

class AggregateInit(Value):
args: List[Value]

def __init__(self, args: List[Value], type: Type, span: Optional[Span] = None) -> None:
super().__init__(type, span)
self.args = args

class Call(Value):
op: FunctionLike
Expand Down
6 changes: 3 additions & 3 deletions luisa_lang/lang.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from luisa_lang.classinfo import VarType, GenericInstance, UnionType, _get_cls_globalns, register_class, class_typeinfo
from enum import Enum, auto
from typing_extensions import TypeAliasType
import typing
from typing import (
Callable,
Dict,
Expand All @@ -14,7 +15,6 @@
Union,
Generic,
Literal,
cast,
overload,
Any,
)
Expand Down Expand Up @@ -109,7 +109,7 @@ def _dsl_func_impl(f: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T:
template = _make_func_template(f, func_name, func_globals)
ctx.functions[f] = template
setattr(f, "__luisa_func__", template)
return cast(_T, f)
return typing.cast(_T, f)
else:
raise NotImplementedError()
# return cast(_T, f)
Expand Down Expand Up @@ -150,7 +150,7 @@ def get_ir_type(var_ty: VarType) -> hir.Type:
def _dsl_decorator_impl(obj: _T, kind: _ObjKind, attrs: Dict[str, Any]) -> _T:
if kind == _ObjKind.STRUCT:
assert isinstance(obj, type), f"{obj} is not a type"
return cast(_T, _dsl_struct_impl(obj, attrs))
return typing.cast(_T, _dsl_struct_impl(obj, attrs))
elif kind == _ObjKind.FUNC or kind == _ObjKind.KERNEL:
return _dsl_func_impl(obj, kind, attrs)
raise NotImplementedError()
Expand Down
5 changes: 4 additions & 1 deletion luisa_lang/lang_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def block_id() -> uint3:


@_builtin
def convert(target: type[_T], value: Any) -> _T:
def cast(target: type[_T], value: Any) -> _T:
"""
Attempt to convert the value to the target type.
"""
Expand Down Expand Up @@ -185,4 +185,7 @@ def value(self, value: _T) -> None:
'static_assert',
'type_of_opt',
'typeof',
"dispatch_id",
"thread_id",
"block_id",
]
Loading

0 comments on commit b51cfb6

Please sign in to comment.