Skip to content

Commit

Permalink
fixed loop codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Nov 5, 2024
1 parent 87e7e52 commit d92e550
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 37 deletions.
56 changes: 54 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ A new Python DSL frontend for LuisaCompute. Will be integrated into LuisaCompute
## Content
- [Introduction](#introduction)
- [Basics](#basic-syntax)
- [Difference from Python](#difference-from-python)
- [Types](#types)
- [Value & Reference Semantics](#value--reference-semantics)
- [Functions](#functions)
- [User-defined Structs](#user-defined-structs)
- [Control Flow](#control-flow)
- [Advanced Usage](#advanced-syntax)
- [Generics](#generics)
- [Metaprogramming](#metaprogramming)
Expand All @@ -20,11 +23,16 @@ A new Python DSL frontend for LuisaCompute. Will be integrated into LuisaCompute
import luisa_lang as lc
```
## Basic Syntax
### Difference from Python
There are some notable differences between luisa_lang and Python:
- Variables have value semantics by default. Use `inout` to indicate that an argument that is passed by reference.
- Generic functions and structs are implemented via monomorphization (a.k.a instantiation) at compile time rather than via type erasure.
- Overloading subscript operator and attribute access is different from Python. Only `__getitem__` and `__getattr__` are needed, which returns a local reference.

### Types
```python
```


### Functions
Functions are defined using the `@lc.func` decorator. The function body can contain any valid LuisaCompute code. You can also include normal Python code that will be executed at DSL comile time using `lc.comptime()`. (See [Metaprogramming](#metaprogramming) for more details)

Expand All @@ -37,13 +45,57 @@ def add(a: lc.float, b: lc.float) -> lc.float:

```

LuisaCompute uses value semantics, which means that all types are passed by value. You can use `inout` to indicate that a variable can be modified in place.

### Value & Reference Semantics
Variables have value semantics by default. This means that when you assign a variable to another, a copy is made.
```python
a = lc.float3(1.0, 2.0, 3.0)
b = a
a.x = 2.0
lc.print(f'{a.x} {b.x}') # prints 2.0 1.0
```

You can use `inout` to indicate that a variable is passed as a *local reference*. Assigning to an `inout` variable will update the original variable.
```python
@luisa.func(a=inout, b=inout)
def swap(a: int, b: int):
a, b = b, a

a = lc.float3(1.0, 2.0, 3.0)
b = lc.float3(4.0, 5.0, 6.0)
swap(a.x, b.x)
lc.print(f'{a.x} {b.x}') # prints 4.0 1.0
```

When overloading subscript operator or attribute access, you actually return a local reference to the object.

#### Local References
Local references are like pointers in C++. However, they cannot escape the expression boundary. This means that you cannot store a local reference in a variable and use it later. While you can return a local reference from a function, it must be returned from a uniform path. That is you cannot return different local references based on a condition.


```python
@lc.struct
class InfiniteArray:
def __getitem__(self, index: int) -> int:
return self.data[index] # returns a local reference

# this method will be ignored by the compiler. but you can still put it here for linting
def __setitem__(self, index: int, value: int):
pass

# Not allowed, non-uniform return
def __getitem__(self, index: int) -> int:
if index == 0:
return self.data[0]
else:
return self.data[1]

```





### User-defined Structs
```python
@lc.struct
Expand Down
53 changes: 36 additions & 17 deletions luisa_lang/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def mangle_impl(self, obj: Union[hir.Type, hir.FunctionLike]) -> str:
case hir.Function(name=name, params=params, return_type=ret):
assert ret
name = mangle_name(name)
params = list(filter(lambda p: not isinstance(
p.type, (hir.FunctionType)), params))
return f'{name}_' + unique_hash(f"F{name}_{self.mangle(ret)}{''.join(self.mangle(unwrap(p.type)) for p in params)}")
case hir.BuiltinFunction(name=name):
name = map_builtin_to_cpp_func(name)
Expand Down Expand Up @@ -203,7 +205,8 @@ def __init__(self, base: CppCodeGen, func: hir.Function) -> None:
self.base = base
self.name = base.mangling.mangle(func)
self.func = func
params = ",".join(self.gen_var(p) for p in func.params)
params = ",".join(self.gen_var(
p) for p in func.params if not isinstance(p.type, hir.FunctionType))
assert func.return_type
self.signature = f'extern "C" auto {self.name}({params}) -> {base.type_cache.gen(func.return_type)}'
self.body = ScratchBuffer()
Expand Down Expand Up @@ -250,6 +253,8 @@ def gen_value_or_ref(self, value: hir.Value | hir.Ref) -> str:
f"unsupported value or reference: {value}")

def gen_expr(self, expr: hir.Value) -> str:
if expr.type and isinstance(expr.type, hir.FunctionType):
return ''
if expr in self.node_map:
return self.node_map[expr]
vid = self.new_vid()
Expand All @@ -269,8 +274,10 @@ def impl() -> None:
f"const auto v{vid} = {base}.{member.field};")
case hir.Call() as call:
op = self.gen_func(call.op)
args_s = ','.join(self.gen_value_or_ref(
arg) for arg in call.args if not isinstance(arg.type, hir.FunctionType))
self.body.writeln(
f"auto v{vid} ={op}({','.join(self.gen_value_or_ref(arg) for arg in call.args)});")
f"auto v{vid} ={op}({args_s});")
case hir.Constant() as constant:
value = constant.value
if isinstance(value, int):
Expand Down Expand Up @@ -302,6 +309,7 @@ def impl() -> None:
return f'v{vid}'

def gen_node(self, node: hir.Node):

match node:
case hir.Return() as ret:
if ret.value:
Expand All @@ -324,31 +332,42 @@ def gen_node(self, node: hir.Node):
self.gen_bb(if_stmt.else_body)
self.body.indent -= 1
self.gen_bb(if_stmt.merge)
case hir.Break():
self.body.writeln("__loop_break = true; break;")
case hir.Continue():
self.body.writeln("break;")
case hir.Loop() as loop:
vid = self.new_vid()
self.body.write(f"auto loop{vid}_prepare = [&]()->bool {{")
"""
while(true) {
bool loop_break = false;
prepare();
if (!cond()) break;
do {
// break => { loop_break = true; break; }
// continue => { break; }
} while(false);
if (loop_break) break;
update();
}
"""
self.body.writeln("while(true) {")
self.body.indent += 1
self.body.writeln("bool __loop_break = false;")
self.gen_bb(loop.prepare)
if loop.cond:
self.body.writeln(f"return {self.gen_expr(loop.cond)};")
else:
self.body.writeln("return true;")
self.body.indent -= 1
self.body.writeln("};")
self.body.writeln(f"auto loop{vid}_body = [&]() {{")
cond = self.gen_expr(loop.cond)
self.body.writeln(f"if (!{cond}) break;")
self.body.writeln("do {")
self.body.indent += 1
self.gen_bb(loop.body)
self.body.indent -= 1
self.body.writeln("};")
self.body.writeln(f"auto loop{vid}_update = [&]() {{")
self.body.indent += 1
self.body.writeln("} while(false);")
self.body.writeln("if (__loop_break) break;")
if loop.update:
self.gen_bb(loop.update)
self.body.indent -= 1
self.body.writeln("};")
self.body.writeln(
f"for(;loop{vid}_prepare();loop{vid}_update());")
self.gen_bb(loop.merge)
self.body.writeln("}")
case hir.Alloca() as alloca:
vid = self.new_vid()
assert alloca.type
Expand Down
16 changes: 9 additions & 7 deletions luisa_lang/hir.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def method(self, name: str) -> Optional[FunctionLike | FunctionTemplate]:

def is_concrete(self) -> bool:
return True

def __len__(self) -> int:
return 1

Expand Down Expand Up @@ -341,7 +341,8 @@ def member(self, field: Any) -> Optional['Type']:

def __len__(self) -> int:
return self.count



class ArrayType(Type):
element: Type
count: Union[int, "SymbolicConstant"]
Expand Down Expand Up @@ -868,6 +869,7 @@ def __init__(self, args: List[Value], type: Type, span: Optional[Span] = None) -
super().__init__(type, span)
self.args = args


class Call(Value):
op: FunctionLike
"""After type inference, op should be a Value."""
Expand Down Expand Up @@ -988,17 +990,17 @@ def __init__(


class Break(Terminator):
target: Loop
target: Loop | None

def __init__(self, target: Loop, span: Optional[Span] = None) -> None:
def __init__(self, target: Loop | None, span: Optional[Span] = None) -> None:
super().__init__(span)
self.target = target


class Continue(Terminator):
target: Loop
target: Loop | None

def __init__(self, target: Loop, span: Optional[Span] = None) -> None:
def __init__(self, target: Loop | None, span: Optional[Span] = None) -> None:
super().__init__(span)
self.target = target

Expand Down Expand Up @@ -1057,7 +1059,7 @@ def update(self, value: Any) -> None:
self.update_func(value)
else:
raise RuntimeError("unable to update comptime value")

def __str__(self) -> str:
return f"ComptimeValue({self.value})"

Expand Down
52 changes: 41 additions & 11 deletions luisa_lang/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class FuncParser:
type_var_ns: Dict[typing.TypeVar, hir.Type | ComptimeValue]
bb_stack: List[hir.BasicBlock]
type_parser: TypeParser
break_and_continues: List[hir.Break | hir.Continue] | None

def __init__(self, name: str,
func: object,
Expand All @@ -173,7 +174,7 @@ def __init__(self, name: str,
self.parsed_func = hir.Function(name, [], None)
self.type_var_ns = type_var_ns
self.bb_stack = []

self.break_and_continues = None
self.parsed_func.params = signature.params
for p in self.parsed_func.params:
self.vars[p.name] = p
Expand Down Expand Up @@ -262,11 +263,12 @@ def parse_name(self, name: ast.Name, new_var_hint: NewVarHint) -> hir.Ref | hir.
if name.id in self.globalns:
resolved = self.globalns[name.id]
return self.convert_any_to_value(resolved, span)
elif name.id in __builtins__: # type: ignore
elif name.id in __builtins__: # type: ignore
resolved = __builtins__[name.id] # type: ignore
return self.convert_any_to_value(resolved, span)
elif new_var_hint == 'comptime':
self.globalns[name.id] = None

def update_fn(value: Any) -> None:
self.globalns[name.id] = value
return ComptimeValue(None, update_fn)
Expand Down Expand Up @@ -379,7 +381,9 @@ def parse_call_impl(self, span: hir.Span | None, f: hir.FunctionLike | hir.Funct
span,
f"Expected {len(template_params)} arguments, got {len(args)}")
for i, (param, arg) in enumerate(zip(template_params, args)):
assert arg.type is not None
if arg.type is None:
raise hir.TypeInferenceError(
span, f"failed to infer type of argument {i}")
template_resolve_args.append((param, arg.type))
resolved_f = f.resolve(template_resolve_args)
if isinstance(resolved_f, hir.TemplateMatchingError):
Expand Down Expand Up @@ -467,6 +471,7 @@ def handle_range() -> hir.Value | ComptimeValue:
args[i] = self.try_convert_comptime_value(
arg, hir.Span.from_ast(expr.args[i]))
converted_args = cast(List[hir.Value], args)

def make_int(i: int) -> hir.Value:
return hir.Constant(i, type=hir.GenericIntType())
if len(args) == 1:
Expand Down Expand Up @@ -516,10 +521,12 @@ def collect_args() -> List[hir.Value | hir.Ref]:
raise hir.ParsingError(expr, call.message)
assert isinstance(call, hir.Call)
return self.cur_bb().append(hir.Load(tmp))

if not isinstance(func, hir.Constant) or not isinstance(func.value, (hir.Function, hir.BuiltinFunction, hir.FunctionTemplate)):
if func.type is not None and isinstance(func.type, hir.FunctionType):
func_like = func.type.func_like
elif not isinstance(func, hir.Constant) or not isinstance(func.value, (hir.Function, hir.BuiltinFunction, hir.FunctionTemplate)):
raise hir.ParsingError(expr, f"function expected")
func_like = func.value
else:
func_like = func.value
ret = self.parse_call_impl(
hir.Span.from_ast(expr), func_like, collect_args())
if isinstance(ret, hir.TemplateMatchingError):
Expand Down Expand Up @@ -791,13 +798,19 @@ def parse_stmt(self, stmt: ast.stmt) -> None:
stmt, "while loop condition must not be a comptime value")
body = hir.BasicBlock(span)
self.bb_stack.append(body)
old_break_and_continues = self.break_and_continues
self.break_and_continues = []
for s in stmt.body:
self.parse_stmt(s)
break_and_continues = self.break_and_continues
self.break_and_continues = old_break_and_continues
body = self.bb_stack.pop()
update = hir.BasicBlock(span)
merge = hir.BasicBlock(span)
pred_bb.append(
hir.Loop(prepare, cond, body, update, merge, span))
loop_node = hir.Loop(prepare, cond, body, update, merge, span)
pred_bb.append(loop_node)
for bc in break_and_continues:
bc.target = loop_node
self.bb_stack.append(merge)
case ast.For():
iter_val = self.parse_expr(stmt.iter)
Expand Down Expand Up @@ -828,12 +841,16 @@ def parse_stmt(self, stmt: ast.stmt) -> None:
self.bb_stack.pop()
body = hir.BasicBlock(span)
self.bb_stack.append(body)
old_break_and_continues = self.break_and_continues
self.break_and_continues = []
for s in stmt.body:
self.parse_stmt(s)
body = self.bb_stack.pop()
break_and_continues = self.break_and_continues
self.break_and_continues = old_break_and_continues
update = hir.BasicBlock(span)
self.bb_stack.append(update)
inc =loop_range.step
inc = loop_range.step
int_add = loop_var.type.method("__add__")
assert int_add is not None
add = self.parse_call_impl(
Expand All @@ -842,9 +859,22 @@ def parse_stmt(self, stmt: ast.stmt) -> None:
self.cur_bb().append(hir.Assign(loop_var, add))
self.bb_stack.pop()
merge = hir.BasicBlock(span)
pred_bb.append(
hir.Loop(prepare, cmp_result, body, update, merge, span))
loop_node = hir.Loop(prepare, cmp_result,
body, update, merge, span)
pred_bb.append(loop_node)
for bc in break_and_continues:
bc.target = loop_node
self.bb_stack.append(merge)
case ast.Break():
if self.break_and_continues is None:
raise hir.ParsingError(
stmt, "break statement must be inside a loop")
self.cur_bb().append(hir.Break(None, span))
case ast.Continue():
if self.break_and_continues is None:
raise hir.ParsingError(
stmt, "continue statement must be inside a loop")
self.cur_bb().append(hir.Continue(None, span))
case ast.Return():
def check_return_type(ty: hir.Type) -> None:
assert self.parsed_func
Expand Down

0 comments on commit d92e550

Please sign in to comment.