diff --git a/pano/vm.py b/pano/vm.py index 1d92ad43..0da47c72 100644 --- a/pano/vm.py +++ b/pano/vm.py @@ -438,7 +438,7 @@ def handle_jumps(self, trace, line, condition): trace.append(("jump", n)) return trace - if op == "jumpi": + elif op == "jumpi": target = stack.pop() if_condition = simplify_bool(stack.pop()) @@ -492,19 +492,7 @@ def handle_jumps(self, trace, line, condition): logger.debug("jumpi -> if %s", trace[-1]) return trace - if op == "selfdestruct": - trace.append(("selfdestruct", stack.pop(),)) - return trace - - if op in ["stop", "assert_fail", "invalid"]: - trace.append((op,)) - return trace - - if op == "UNKNOWN": - trace.append(("invalid",)) - return trace - - if op in ["return", "revert"]: + elif op in ["return", "revert"]: p = stack.pop() n = stack.pop() @@ -516,6 +504,18 @@ def handle_jumps(self, trace, line, condition): return trace + elif op in ["stop", "assert_fail", "invalid"]: + trace.append((op,)) + return trace + + elif op == "UNKNOWN": + trace.append(("invalid",)) + return trace + + elif op == "selfdestruct": + trace.append(("selfdestruct", stack.pop(),)) + return trace + return None def apply_stack(self, ret, line): @@ -550,16 +550,6 @@ def trace(exp, *format_args): else: trace("[{}] {} {}", line[0], C.asm(op), C.asm(str(line[2]))) - assert op not in [ - "jump", - "jumpi", - "revert", - "return", - "stop", - "jumpdest", - "UNKNOWN", - ] - param = 0 if len(line) > 2: param = line[2] @@ -581,16 +571,37 @@ def trace(exp, *format_args): ]: stack.append(arithmetic.eval((op, stack.pop(), stack.pop(),))) - if op in ["mulmod", "addmod"]: - stack.append(("mulmod", stack.pop(), stack.pop(), stack.pop())) + elif op[:4] == "push": + stack.append(param) + + elif op == "pop": + stack.pop() - if op == "mul": + elif op == "dup": + stack.dup(param) + + elif op == "mul": stack.append(mul_op(stack.pop(), stack.pop())) - if op == "or": + elif op == "or": stack.append(or_op(stack.pop(), stack.pop())) - if op == "shl": + elif op == "add": + stack.append(add_op(stack.pop(), stack.pop())) + + elif op == "sub": + left = stack.pop() + right = stack.pop() + + if type(left) == int and type(right) == int: + stack.append(arithmetic.sub(left, right)) + else: + stack.append(sub_op(left, right)) + + elif op in ["not", "iszero"]: + stack.append((op, stack.pop())) + + elif op == "shl": off = stack.pop() exp = stack.pop() if all_concrete(off, exp): @@ -598,7 +609,7 @@ def trace(exp, *format_args): else: stack.append(mask_op(exp, shl=off)) - if op == "shr": + elif op == "shr": off = stack.pop() exp = stack.pop() if all_concrete(off, exp): @@ -606,7 +617,7 @@ def trace(exp, *format_args): else: stack.append(mask_op(exp, offset=minus_op(off), shr=off)) - if op == "sar": + elif op == "sar": off = stack.pop() exp = stack.pop() if all_concrete(off, exp): @@ -625,20 +636,25 @@ def trace(exp, *format_args): # FIXME: This won't give the right result... stack.append(mask_op(exp, offset=minus_op(off), shr=off)) - if op == "add": - stack.append(add_op(stack.pop(), stack.pop())) + elif op == "mstore": + memloc = stack.pop() + val = stack.pop() + trace(("setmem", ("range", memloc, 32), val,)) - if op == "sub": - left = stack.pop() - right = stack.pop() + elif op == "msize": + self.counter += 1 + vname = f"_{self.counter}" + trace(("setvar", vname, "msize")) + stack.append(("var", vname)) - if type(left) == int and type(right) == int: - stack.append(arithmetic.sub(left, right)) - else: - stack.append(sub_op(left, right)) + elif op == "mload": + memloc = stack.pop() + loaded = mem_load(memloc) - elif op in ["not", "iszero"]: - stack.append((op, stack.pop())) + self.counter += 1 + vname = f"_{self.counter}" + trace(("setvar", vname, ("mem", ("range", memloc, 32)))) + stack.append(("var", vname)) elif op == "sha3": p = stack.pop() @@ -664,9 +680,6 @@ def trace(exp, *format_args): off = sub_op(256, to_bytes(num)) stack.append(mask_op(val, 8, off, shr=off)) - elif op == "selfbalance": - stack.append(("balance", "address",)) - elif op == "balance": addr = stack.pop() if opcode(addr) == "mask_shl" and addr[:4] == ("mask_shl", 160, 0, 0): @@ -674,9 +687,30 @@ def trace(exp, *format_args): else: stack.append(("balance", addr,)) + elif op in [ + "callvalue", + "caller", + "address", + "number", + "gas", + "origin", + "timestamp", + "chainid", + "difficulty", + "gasprice", + "coinbase", + "gaslimit", + "calldatasize", + "returndatasize", + ]: + stack.append(op) + elif op == "swap": stack.swap(param) + elif op == "selfbalance": + stack.append(("balance", "address",)) + elif op[:3] == "log": p = stack.pop() s = stack.pop() @@ -697,26 +731,6 @@ def trace(exp, *format_args): val = stack.pop() trace(("store", 256, 0, sloc, val)) - elif op == "mload": - memloc = stack.pop() - loaded = mem_load(memloc) - - self.counter += 1 - vname = f"_{self.counter}" - trace(("setvar", vname, ("mem", ("range", memloc, 32)))) - stack.append(("var", vname)) - - elif op == "mstore": - memloc = stack.pop() - val = stack.pop() - trace(("setmem", ("range", memloc, 32), val,)) - - elif op == "mstore8": - memloc = stack.pop() - val = stack.pop() - - trace(("setmem", ("range", memloc, 8), val,)) - elif op == "extcodecopy": addr = stack.pop() mem_pos = stack.pop() @@ -896,44 +910,31 @@ def trace(exp, *format_args): stack.append("create2.new_address") - elif op[:4] == "push": - stack.append(param) + elif op in ("extcodesize", "extcodehash", "blockhash"): + stack.append((op, stack.pop(),)) + + elif op in ["mulmod", "addmod"]: + stack.append(("mulmod", stack.pop(), stack.pop(), stack.pop())) elif op == "pc": stack.append(line[0]) - elif op == "pop": - stack.pop() - - elif op == "dup": - stack.dup(param) - - elif op == "msize": - self.counter += 1 - vname = f"_{self.counter}" - trace(("setvar", vname, "msize")) - stack.append(("var", vname)) + elif op == "mstore8": + memloc = stack.pop() + val = stack.pop() - elif op in ("extcodesize", "extcodehash", "blockhash"): - stack.append((op, stack.pop(),)) + trace(("setmem", ("range", memloc, 8), val,)) - elif op in [ - "callvalue", - "caller", - "address", - "number", - "gas", - "origin", - "timestamp", - "chainid", - "difficulty", - "gasprice", - "coinbase", - "gaslimit", - "calldatasize", - "returndatasize", - ]: - stack.append(op) + else: + assert op not in [ + "jump", + "jumpi", + "revert", + "return", + "stop", + "jumpdest", + "UNKNOWN", + ] if stack.len() - previous_len != opcode_dict.stack_diffs[op]: logger.error("line: %s", line)