Skip to content

Commit

Permalink
added comparison and unary op
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Nov 5, 2024
1 parent 1646db0 commit 699aff2
Showing 1 changed file with 63 additions and 42 deletions.
105 changes: 63 additions & 42 deletions luisa_lang/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,24 @@ def collect_args() -> List[hir.Value | hir.Ref]:
raise hir.ParsingError(expr, ret.message)
return ret

def parse_binop(self, expr: ast.BinOp) -> hir.Value:
# def parse_compare(self, expr: ast.Compare) -> hir.Value | ComptimeValue:
# cmpop_to_str: Dict[type, str] = {
# ast.Eq: "==",
# ast.NotEq: "!=",
# ast.Lt: "<",
# ast.LtE: "<=",
# ast.Gt: ">",
# ast.GtE: ">="
# }
# if len(expr.ops) != 1:
# raise hir.ParsingError(expr, "only one comparison operator is allowed")
# op = expr.ops[0]
# if type(op) not in cmpop_to_str:
# raise hir.ParsingError(expr, f"unsupported comparison operator {type(op)}")
# op_str = cmpop_to_str[type(op)]
# method_name = BINOP_TO_METHOD_NAMES[type(op)]

def parse_binop(self, expr: ast.BinOp | ast.Compare) -> hir.Value:
binop_to_op_str: Dict[type, str] = {
ast.Add: "+",
ast.Sub: "-",
Expand All @@ -556,20 +573,32 @@ def parse_binop(self, expr: ast.BinOp) -> hir.Value:
ast.GtE: ">=",

}
op_str = binop_to_op_str[type(expr.op)]
lhs = self.parse_expr(expr.left)
op: ast.AST
if isinstance(expr, ast.Compare):
if len(expr.ops) != 1:
raise hir.ParsingError(
expr, "only one comparison operator is allowed")
op = expr.ops[0]
left = expr.left
right = expr.comparators[0]
else:
op = expr.op
left = expr.left
right = expr.right
op_str = binop_to_op_str[type(op)]
lhs = self.parse_expr(left)
if isinstance(lhs, ComptimeValue):
lhs = self.try_convert_comptime_value(lhs, hir.Span.from_ast(expr))
if not lhs.type:
raise hir.ParsingError(
expr.left, f"unable to infer type of left operand of binary operation {op_str}")
rhs = self.parse_expr(expr.right)
left, f"unable to infer type of left operand of binary operation {op_str}")
rhs = self.parse_expr(right)
if isinstance(rhs, ComptimeValue):
rhs = self.try_convert_comptime_value(rhs, hir.Span.from_ast(expr))
if not rhs.type:
raise hir.ParsingError(
expr.right, f"unable to infer type of right operand of binary operation {op_str}")
ops = BINOP_TO_METHOD_NAMES[type(expr.op)]
right, f"unable to infer type of right operand of binary operation {op_str}")
ops = BINOP_TO_METHOD_NAMES[type(op)]

def infer_binop(name: str, rname: str) -> hir.Value:
assert lhs.type and rhs.type
Expand Down Expand Up @@ -712,6 +741,30 @@ def check(i: int, val_type: hir.Type) -> None:
raise hir.ParsingError(
targets[0], f"unsupported type for unpacking: {values.type}")

def parse_unary(self, expr: ast.UnaryOp) -> hir.Value:
op = expr.op
if type(op) not in UNARY_OP_TO_METHOD_NAMES:
raise hir.ParsingError(
expr, f"unsupported unary operator {type(op)}")
op_str = UNARY_OP_TO_METHOD_NAMES[type(op)]
operand = self.parse_expr(expr.operand)
if isinstance(operand, ComptimeValue):
operand = self.try_convert_comptime_value(
operand, hir.Span.from_ast(expr))
if not operand.type:
raise hir.ParsingError(
expr.operand, f"unable to infer type of operand of unary operation {op_str}")
method_name = UNARY_OP_TO_METHOD_NAMES[type(op)]
if (method := operand.type.method(method_name)) and method:
ret = self.parse_call_impl(
hir.Span.from_ast(expr), method, [operand])
if isinstance(ret, hir.TemplateMatchingError):
raise hir.ParsingError(expr, ret.message)
return ret
else:
raise hir.ParsingError(
expr, f"operator {type(op)} not defined for type {operand.type}")

def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue:
match expr:
case ast.Constant():
Expand All @@ -723,8 +776,10 @@ def parse_expr(self, expr: ast.expr) -> hir.Value | ComptimeValue:
return ret
case ast.Subscript() | ast.Attribute():
return self.parse_access(expr)
case ast.BinOp():
case ast.BinOp() | ast.Compare():
return self.parse_binop(expr)
case ast.UnaryOp():
return self.parse_unary(expr)
case ast.Call():
return self.parse_call(expr)
case ast.Tuple():
Expand Down Expand Up @@ -970,40 +1025,6 @@ def parse_anno_ty() -> hir.Type:
if stmt.value:
self.parse_multi_assignment(
[stmt.target], [parse_anno_ty], self.parse_expr(stmt.value))
# value = self.parse_expr(stmt.value)
# if isinstance(value, ComptimeValue):
# var = self.parse_ref(
# stmt.target, new_var_hint='comptime')
# else:
# var = self.parse_ref(stmt.target, new_var_hint='dsl')
# if isinstance(var, ComptimeValue):
# if isinstance(value, ComptimeValue):
# try:
# var.update(value.value)
# except Exception as e:
# raise hir.ParsingError(
# stmt, f"error updating comptime value: {e}") from e
# return
# else:
# raise hir.ParsingError(
# stmt, f"comptime value cannot be assigned with DSL value")
# else:
# if isinstance(value, ComptimeValue):
# value = self.try_convert_comptime_value(
# value, span)
# assert value.type
# anno_ty = parse_anno_ty()
# if not var.type:
# var.type = value.type
# if not var.type.is_concrete():
# raise hir.ParsingError(
# stmt, "only concrete type can be assigned, please annotate the variable with concrete types")
# if not hir.is_type_compatible_to(value.type, anno_ty):
# raise hir.ParsingError(
# stmt, f"expected {anno_ty}, got {value.type}")
# if not value.type.is_concrete():
# value.type = var.type
# self.cur_bb().append(hir.Assign(var, value, span))
else:
var = self.parse_ref(stmt.target, new_var_hint='dsl')
anno_ty = parse_anno_ty()
Expand Down

0 comments on commit 699aff2

Please sign in to comment.