diff --git a/docs/collapse_literals.rst b/docs/collapse_literals.rst index f782b9d..a2a3e71 100644 --- a/docs/collapse_literals.rst +++ b/docs/collapse_literals.rst @@ -70,3 +70,4 @@ If the branch is constant, and thus known at decoration time, then this flaw won .. todo:: Support sets? .. todo:: Always commit changes within a block, and only mark values as non-deterministic outside of conditional blocks .. todo:: Support list/set/dict comprehensions +.. todo:: Support known elements of format strings (JoinedStr) in python 3.6+ diff --git a/docs/inline.rst b/docs/inline.rst index e8019a3..8ed65e7 100644 --- a/docs/inline.rst +++ b/docs/inline.rst @@ -6,12 +6,12 @@ Inlining Functions Inline specified functions into the decorated function. Unlike in C, this directive is placed not on the function getting inlined, but rather the function into which it's getting inlined (since that's the one whose code needs to be modified and hence decorated). Currently, this is implemented in the following way: - When a function is called, its call code is placed within the current code block immediately before the line where its value is needed -- The code is wrapped in a one-iteration ``for`` loop (effectively a ``do {} while(0)``), and the ``return`` statement is replaced by a ``break`` +- The code is wrapped in a ``try/except`` block, and the return value is passed back out using a special exception type - Arguments are stored into a dictionary, and variadic keyword arguments are passed as ``dict_name.update(kwargs)``; this dictionary has the name ``_[funcname]`` where ``funcname`` is the name of the function being inlined, so other variables of this name should not be used or relied upon - The return value is assigned to the function name as well, deleting the argument dictionary, freeing its memory, and making the return value usable when the function's code is exited by the ``break`` - The call to the function is replaced by the variable holding the return value -As a result, ``pragma.inline`` cannot currently handle functions which contain a ``return`` statement within a loop. Since Python doesn't support anything like ``goto`` besides wrapping the code in a function (which this function implicitly shouldn't do), I don't know how to surmount this problem. Without much effort, it can be overcome by tailoring the function to be inlined. +As a result, ``pragma.inline`` cannot currently handle functions which contain a ``return`` statement within a bare ``try/except`` or ``except BaseException``. Since Python doesn't support anything like ``goto`` besides wrapping the code in a function (which this function implicitly shouldn't do), I don't know how to surmount this problem. Without much effort, it can be overcome by tailoring the function to be inlined. In general, it's bad practice to use a bare ``except:`` or ``except BaseException:``, and such calls should generally be replaced with ``except Exception:``, which would this issue. To inline a function ``f`` into the code of another function ``g``, use ``pragma.inline(g)(f)``, or, as a decorator:: @@ -23,49 +23,37 @@ To inline a function ``f`` into the code of another function ``g``, use ``pragma z = y + 3 return f(z * 4) - # ... g Becomes something like ... + # ... g Becomes... def g(y): z = y + 3 - _f = dict(x=z * 4) # Store arguments - for ____ in [None]: # Function body - _f['return'] = _f['x'] ** 2 # Store the "return"ed value - break # Return, terminate the function body - _f_return = _f.get('return', None) # Retrieve the returned value - del _f # Discard everything else - return _f_return - -This loop can be removed, if it's not necessary, using :func:``pragma.unroll``. This can be accomplished if there are no returns within a conditional or loop block. In this case:: - - def f(x): - return x**2 - - @pragma.unroll - @pragma.inline(f) - def g(y): - z = y + 3 - return f(z * 4) - - # ... g Becomes ... - - def g(y): - z = y + 3 - _f = {} - _f['x'] = z * 4 - _f = _f['x'] ** 2 - return _f - -It needs to be noted that, besides arguments getting stored into a dictionary, other variable names remain unaltered when inlined. Thus, if there are shared variable names in the two functions, they might overwrite each other in the resulting inlined function. + _f_0 = dict(x=z * 4) + try: # Function body + raise _PRAGMA_INLINE_RETURN(_f_0['x'] ** 2) + except _PRAGMA_INLINE_RETURN as _f_return_0_exc: + _f_return_0 = _f_return_0_exc.return_val + else: + _f_return_0 = None + finally: # Discard artificial stack frame + del _f_0 + return _f_return_0 .. todo:: Fix name collision by name-mangling non-free variables -Eventually, this could be collapsed using :func:``pragma.collapse_literals``, to produce simply ``return ((y + 3) * 4) ** 2``, but dictionaries aren't yet supported for collapsing. +Eventually, this could be collapsed using :func:``pragma.collapse_literals``, to produce simply ``return ((y + 3) * 4) ** 2``, but there are numerous hurtles in the way toward making this happen. -When inlining a generator function, the function's results are collapsed into a list, which is then returned. This will break in two main scenarios: +When inlining a generator function, the function's results are collapsed into a list, which is then returned. This is equivalent to calling ``list(generator_func(*args, **kwargs))``. This will break in two main scenarios: - The generator never ends, or consumes excessive amounts of resources. -- The calling code relies on the resulting generator being more than just iterable. +- The calling code relies on the resulting generator being more than just iterable, e.g. if data is passed back in using calls to ``next``. + +.. todo:: Fix generators to return something more like ``iter(list(f(*args, **kwargs))``, since ``list`` itself is not an iterator, but the return of a generator is. In general, either this won't be an issue, or you should know better than to try to inline the infinite generator. -.. todo:: Support inlining a generator into another generator by merging the functions together. E.g., ``for x in my_range(5): yield x + 2`` becomes ``i = 0; while i < 5: yield i + 2; i += 1`` (or something vaguely like that). \ No newline at end of file +.. todo:: Support inlining a generator into another generator by merging the functions together. E.g., ``for x in my_range(5): yield x + 2`` becomes ``i = 0; while i < 5: yield i + 2; i += 1`` (or something vaguely like that). +.. todo:: Support inlining closures; if the inlined function refers to global or nonlocal variables, import them into the closure of the final function. + +Recursive calls are handled by keeping a counter of the inlined recursion depth, and changing the suffix number of the local variables dictionary (e.g., ``_f_0``). These dictionaries serve as stack frames: their unique naming permits multiple, even stacked, inlined function calls, and their deletion enforces the usual life span of function-local variables. + +.. todo:: Support option to either inline as loop or exception \ No newline at end of file diff --git a/pragma/collapse_literals.py b/pragma/collapse_literals.py index 3809ded..12f7048 100644 --- a/pragma/collapse_literals.py +++ b/pragma/collapse_literals.py @@ -8,9 +8,10 @@ # noinspection PyPep8Naming class CollapseTransformer(TrackedContextTransformer): def visit_Name(self, node): - res = self.resolve_literal(node) - if isinstance(res, primitive_ast_types): - return res + if isinstance(node.ctx, ast.Load): + res = self.resolve_literal(node) + if isinstance(res, primitive_ast_types): + return res return node def visit_BinOp(self, node): diff --git a/pragma/core/transformer.py b/pragma/core/transformer.py index 0b79ded..0b8dced 100644 --- a/pragma/core/transformer.py +++ b/pragma/core/transformer.py @@ -36,21 +36,27 @@ def function_ast(f): class DebugTransformerMixin: # pragma: nocover def visit(self, node): - orig_node_code = astor.to_source(node).strip() - log.debug("Starting to visit >> {} << ({})".format(orig_node_code, type(node))) + cls = type(self).__name__ + + try: + orig_node_code = astor.to_source(node).strip() + except Exception as ex: + log.error("{} ({})".format(type(node), astor.dump_tree(node)), exc_info=ex) + raise ex + log.debug("{} Starting to visit >> {} << ({})".format(cls, orig_node_code, type(node))) new_node = super().visit(node) try: if new_node is None: - log.debug("Deleted >>> {} <<<".format(orig_node_code)) + log.debug("{} Deleted >>> {} <<<".format(cls, orig_node_code)) elif isinstance(new_node, ast.AST): - log.debug("Converted >>> {} <<< to >>> {} <<<".format(orig_node_code, astor.to_source(new_node).strip())) + log.debug("{} Converted >>> {} <<< to >>> {} <<<".format(cls, orig_node_code, astor.to_source(new_node).strip())) elif isinstance(new_node, list): - log.debug("Converted >>> {} <<< to [[[ {} ]]]".format(orig_node_code, ", ".join( + log.debug("{} Converted >>> {} <<< to [[[ {} ]]]".format(cls, orig_node_code, ", ".join( astor.to_source(n).strip() for n in new_node))) except Exception as ex: - log.error("Failed on {} >>> {}".format(orig_node_code, astor.dump_tree(new_node)), exc_info=ex) + log.error("{} Failed on {} >>> {}".format(cls, orig_node_code, astor.dump_tree(new_node)), exc_info=ex) raise ex return new_node @@ -237,11 +243,12 @@ def assign(self, name, val): log.debug("Failed to assign {}={}, rvalue cannot be converted to AST".format(name, val)) def visit_Assign(self, node): - node.value = self.visit(node.value) + node = self.generic_visit(node) self.assign(node.targets, node.value) return node def visit_AugAssign(self, node): + node.target = self.visit(node.target) node = copy.deepcopy(node) node.value = self.visit(node.value) new_val = self.resolve_literal(ast.BinOp(op=node.op, left=node.target, right=node.value)) @@ -292,40 +299,46 @@ def visit_ClassDef(self, node): def visit_For(self, node): node.iter = self.visit(node.iter) + node.target = self.visit(node.target) node.body = self.nested_visit(node.body) node.orelse = self.nested_visit(node.orelse) - return self.generic_visit_less(node, 'body', 'orelse', 'iter') + return node def visit_AsyncFor(self, node): node.iter = self.visit(node.iter) + node.target = self.visit(node.target) node.body = self.nested_visit(node.body) node.orelse = self.nested_visit(node.orelse) - return self.generic_visit_less(node, 'body', 'orelse', 'iter') + return node def visit_While(self, node): + node.test = self.visit(node.test) node.body = self.nested_visit(node.body) node.orelse = self.nested_visit(node.orelse) - return self.generic_visit_less(node, 'body', 'orelse') + return node def visit_If(self, node): node.test = self.visit(node.test) node.body = self.nested_visit(node.body) node.orelse = self.nested_visit(node.orelse) - return self.generic_visit_less(node, 'body', 'orelse', 'test') + return node def visit_With(self, node): + node.items = self.nested_visit(node.items, set_conditional_exec=False) node.body = self.nested_visit(node.body, set_conditional_exec=False) - return self.generic_visit_less(node, 'body') + return node def visit_AsyncWith(self, node): + node.items = self.nested_visit(node.items, set_conditional_exec=False) node.body = self.nested_visit(node.body, set_conditional_exec=False) - return self.generic_visit_less(node, 'body') + return node def visit_Try(self, node): node.body = self.nested_visit(node.body) + node.handlers = self.nested_visit(node.handlers) node.orelse = self.nested_visit(node.orelse) node.finalbody = self.nested_visit(node.finalbody, set_conditional_exec=False) - return self.generic_visit_less(node, 'body', 'orelse', 'finalbody') + return node def visit_Module(self, node): node.body = self.nested_visit(node.body, set_conditional_exec=False) diff --git a/pragma/inline.py b/pragma/inline.py index 8f3ac3b..f1b0c69 100644 --- a/pragma/inline.py +++ b/pragma/inline.py @@ -46,6 +46,13 @@ # -- col_offset is the byte offset in the utf8 string the parser uses # attributes (int lineno, int col_offset) + +class _PRAGMA_INLINE_RETURN(BaseException): + def __init__(self, val=None): + super().__init__() + self.return_val = val + + DICT_FMT = "_{fname}_{n}" @@ -58,7 +65,7 @@ def make_name(fname, var, n, ctx=ast.Load): :param var: Argument name :type var: str :param ctx: Context of this name (LOAD or STORE) - :type ctx: Load|Store + :type ctx: type :param n: The number to append to this name (to allow for finite recursion) :type n: int :param fmt: Name format (if not stored in a dictionary) @@ -74,36 +81,48 @@ def make_name(fname, var, n, ctx=ast.Load): class _InlineBodyTransformer(TrackedContextTransformer): def __init__(self, func_name, param_names, n): self.func_name = func_name - # print("Func {} takes parameters {}".format(func_name, param_names)) - self.param_names = param_names - self.in_break_block = False + self.param_names = list(param_names) + self.local_names = set(param_names) + self.nonlocal_names = set() + self.has_global_catch = False self.n = n self.had_return = False self.had_yield = False super().__init__() + def __setitem__(self, key, value): + self.local_names.add(key) + super().__setitem__(key, value) + + def visit_Global(self, node): + self.nonlocal_names |= node.names + self.local_names -= node.names + return self.generic_visit(node) + + def visit_Nonlocal(self, node): + self.nonlocal_names |= node.names + self.local_names -= node.names + return self.generic_visit(node) + def visit_Name(self, node): - # Check if this is a parameter, and hasn't had another value assigned to it - if node.id in self.param_names: - # print("Found parameter reference {}".format(node.id)) - if node.id not in self.ctxt: - # If so, get its value from the argument dictionary - return make_name(self.func_name, node.id, self.n, ctx=type(getattr(node, 'ctx', ast.Load()))) - else: - # print("But it's been overwritten to {} = {}".format(node.id, self.ctxt[node.id])) - pass + if isinstance(node.ctx, ast.Store) and node.id not in self.nonlocal_names: + self.local_names.add(node.id) + + log.debug("Is '{}' ({}) a local variable? locals={}, nonlocals={}".format(node.id, type(node.ctx), self.local_names, self.nonlocal_names)) + if node.id in self.local_names: + return make_name(self.func_name, node.id, self.n, ctx=type(getattr(node, 'ctx', ast.Load()))) return node def visit_Return(self, node): - if self.in_break_block: - raise NotImplementedError("miniutils.pragma.inline cannot handle returns from within a loop") - result = [] - if node.value: - result.append(ast.Assign(targets=[make_name(self.func_name, 'return', self.n, ctx=ast.Store)], - value=self.visit(node.value))) - result.append(ast.Break()) self.had_return = True - return result + return ast.Raise( + exc=ast.Call( + func=ast.Name(id=_PRAGMA_INLINE_RETURN.__name__, ctx=ast.Load()), + args=[self.visit(node.value)] if node.value is not None else [], + keywords=[] + ), + cause=None + ) def visit_Yield(self, node): self.had_yield = True @@ -123,19 +142,11 @@ def visit_YieldFrom(self, node): args=[self.visit(node.value)], keywords=[]) - def visit_For(self, node): - orig_in_break_block = self.in_break_block - self.in_break_block = True - res = super().visit_For(node) - self.in_break_block = orig_in_break_block - return res - - def visit_While(self, node): - orig_in_break_block = self.in_break_block - self.in_break_block = True - res = super().visit_While(node) - self.in_break_block = orig_in_break_block - return res + def visit_ExceptHandler(self, node): + node = self.generic_visit(node) + if node.type is None or issubclass(BaseException, self.resolve_name_or_attribute(node.type)): + self.has_global_catch = True + return node class InlineTransformer(TrackedContextTransformer): @@ -149,134 +160,149 @@ def __init__(self, *args, funs=None, max_depth=1, **kwargs): def visit_Call(self, node): """When we see a function call, insert the function body into the current code block, then replace the call with the return expression """ + node = self.generic_visit(node) node_fun = self.resolve_name_or_attribute(self.resolve_literal(node.func)) - for (fun, fname, fsig, fbody) in self.funs: - if fun != node_fun: - continue + try: + fun, fname, fsig, fbody = next(f for f in self.funs if f[0] == node_fun) + except StopIteration: + return node + + possible_dict_names = ((i, DICT_FMT.format(fname=fname, n=i)) for i in range(self.max_depth)) + possible_dict_names = ((i, name) for i, name in possible_dict_names if name not in self.ctxt) + try: + n, args_dict_name = next(possible_dict_names) + except StopIteration: + warnings.warn("Inline hit recursion limit, using normal function call") + return node - n = 0 - for i in range(self.max_depth): - args_dict_name = DICT_FMT.format(fname=fname, n=i) - n = i # This is redundant, but a bit clearer and safer than just referencing i later - if args_dict_name not in self.ctxt: - break - else: - warnings.warn("Inline hit recursion limit, using normal function call") - return node - - func_for_inlining = _InlineBodyTransformer(fname, fsig.parameters, n) - fbody = list(func_for_inlining.visit_many(copy.deepcopy(fbody))) - - # print(self.code_blocks) - cur_block = self.code_blocks[-1] - new_code = [] - - # Load arguments into their appropriate variables - args = node.args - flattened_args = [] - for a in args: - if isinstance(a, ast.Starred): - a = self.resolve_iterable(a.value) - if a: - flattened_args.extend(a) - else: - warnings.warn("Cannot inline function call that uses non-constant star args") - return node - else: - flattened_args.append(a) - - keywords = [(kw.arg, kw.value) for kw in node.keywords if kw.arg is not None] - kw_dict = [kw.value for kw in node.keywords if kw.arg is None] - kw_dict = kw_dict[0] if kw_dict else None - - bound_args = fsig.bind(*flattened_args, **odict(keywords)) - bound_args.apply_defaults() - - # Create args dictionary - final_args = [] - final_kwargs = [] - - for arg_name, arg_value in bound_args.arguments.items(): - if isinstance(arg_value, tuple): - arg_value = ast.Tuple(elts=list(arg_value), ctx=ast.Load()) - elif isinstance(arg_value, dict): - keys, values = zip(*list(arg_value.items())) - keys = [ast.Str(k) for k in keys] - values = list(values) - arg_value = ast.Dict(keys=keys, values=values) - # fun_name['param_name'] = param_value - final_kwargs.append((arg_name, arg_value)) - - if kw_dict: - final_kwargs.append((None, kw_dict)) - - if func_for_inlining.had_yield: - final_args.append(ast.List(elts=[ast.Tuple(elts=[ast.Str('yield'), ast.List(elts=[], ctx=ast.Load())], - ctx=ast.Load())], - ctx=ast.Load())) - - # fun_name = {} - dict_call = ast.Call( - func=ast.Name(id='dict', ctx=ast.Load()), - args=final_args, - keywords=[ast.keyword(arg=name, value=val) for name, val in final_kwargs] - ) - new_code.append(ast.Assign( - targets=[ast.Name(id=args_dict_name, ctx=ast.Store())], - value=dict_call - )) + func_for_inlining = _InlineBodyTransformer(fname, fsig.parameters, n) + fbody = list(func_for_inlining.visit_many(copy.deepcopy(fbody))) + + if func_for_inlining.has_global_catch: + warnings.warn("Unable to inline function with an unbound except statement") + return node - # Process assignments before resolving body - cur_block.extend(self.visit_many(new_code)) - - # Inline function code - new_body = list(self.visit_many(fbody)) - - # cur_block.append(self.visit(ast.For(target=ast.Name(id='____', ctx=ast.Store()), - # iter=ast.List(elts=[ast.NameConstant(None)], ctx=ast.Load()), - # body=new_body, - # orelse=[]))) - cur_block.append(ast.For(target=ast.Name(id='____', ctx=ast.Store()), - iter=ast.List(elts=[ast.NameConstant(None)], ctx=ast.Load()), - body=new_body, - orelse=[])) - - # fun_name['return'] - if func_for_inlining.had_yield or func_for_inlining.had_return: - for j in range(100000): - output_name = DICT_FMT.format(fname=fname + '_return', n=j) - if output_name not in self.ctxt: - break + # print(self.code_blocks) + cur_block = self.code_blocks[-1] + new_code = [] + + # Load arguments into their appropriate variables + args = node.args + flattened_args = [] + for a in args: + if isinstance(a, ast.Starred): + a = self.resolve_iterable(a.value) + if a: + flattened_args.extend(a) else: - raise RuntimeError("Function {} called and returned too many times during inlining, not able to " - "put the return value into a uniquely named variable".format(fname)) - - if func_for_inlining.had_yield: - cur_block.append(self.visit(ast.Assign(targets=[ast.Name(id=output_name, ctx=ast.Store())], - value=make_name(fname, 'yield', n)))) - elif func_for_inlining.had_return: - get_call = ast.Call( - func=ast.Attribute( - value=ast.Name(id=args_dict_name, ctx=ast.Load()), - attr='get', - ctx=ast.Load()), - args=[ast.Str('return'), ast.NameConstant(None)], - keywords=[] - ) - cur_block.append(self.visit(ast.Assign(targets=[ast.Name(id=output_name, ctx=ast.Store())], - value=get_call))) - - return_node = ast.Name(id=output_name, ctx=ast.Load()) + warnings.warn("Cannot inline function call that uses non-constant star args") + return node else: - return_node = ast.NameConstant(None) - - cur_block.append(self.visit(ast.Delete(targets=[ast.Name(id=args_dict_name, ctx=ast.Del())]))) - return return_node - + flattened_args.append(a) + + keywords = [(kw.arg, kw.value) for kw in node.keywords if kw.arg is not None] + kw_dict = [kw.value for kw in node.keywords if kw.arg is None] + kw_dict = kw_dict[0] if kw_dict else None + + bound_args = fsig.bind(*flattened_args, **odict(keywords)) + bound_args.apply_defaults() + + # Create args dictionary + final_args = [] + final_kwargs = [] + + for arg_name, arg_value in bound_args.arguments.items(): + if isinstance(arg_value, tuple): + arg_value = ast.Tuple(elts=list(arg_value), ctx=ast.Load()) + elif isinstance(arg_value, dict): + keys, values = zip(*list(arg_value.items())) + keys = [ast.Str(k) for k in keys] + values = list(values) + arg_value = ast.Dict(keys=keys, values=values) + # fun_name['param_name'] = param_value + final_kwargs.append((arg_name, arg_value)) + + if kw_dict: + final_kwargs.append((None, kw_dict)) + + if func_for_inlining.had_yield: + final_args.append(ast.List(elts=[ast.Tuple(elts=[ast.Str('yield'), ast.List(elts=[], ctx=ast.Load())], + ctx=ast.Load())], + ctx=ast.Load())) + + # fun_name = {} + dict_call = ast.Call( + func=ast.Name(id='dict', ctx=ast.Load()), + args=final_args, + keywords=[ast.keyword(arg=name, value=val) for name, val in final_kwargs] + ) + new_code.append(ast.Assign( + targets=[ast.Name(id=args_dict_name, ctx=ast.Store())], + value=dict_call + )) + + # Process assignments before resolving body + cur_block.extend(self.visit_many(new_code)) + + # Inline function code + new_body = list(self.visit_many(fbody)) + + for j in range(100000): + output_name = DICT_FMT.format(fname=fname + '_return', n=j) + if output_name not in self.ctxt: + break else: - return node + raise RuntimeError("Function {} called and returned too many times during inlining, not able to " + "put the return value into a uniquely named variable".format(fname)) + + return_node = ast.Name(id=output_name, ctx=ast.Load()) + + if func_for_inlining.had_yield: + afterwards_body = ast.Assign(targets=[ast.Name(id=output_name, ctx=ast.Store())], + value=make_name(fname, 'yield', n)) + elif func_for_inlining.had_return: + afterwards_body = ast.Assign(targets=[ast.Name(id=output_name, ctx=ast.Store())], + value=ast.Attribute( + value=ast.Name(id=output_name + "_exc", ctx=ast.Load()), + attr='return_val', + ctx=ast.Load() + )) + else: + return_node = ast.NameConstant(None) + afterwards_body = ast.Pass() + + self.visit(afterwards_body) + afterwards_body = [afterwards_body] if isinstance(afterwards_body, ast.AST) else afterwards_body + + if func_for_inlining.had_return: + cur_block.append(ast.Try( + body=new_body, + handlers=[ast.ExceptHandler( + type=ast.Name(id=_PRAGMA_INLINE_RETURN.__name__, ctx=ast.Load()), + name=output_name + "_exc", + body=afterwards_body + )], + orelse=afterwards_body if func_for_inlining.had_yield else [ + self.visit(ast.Assign(targets=[ast.Name(id=output_name, ctx=ast.Store())], + value=ast.NameConstant(None))) + ], + finalbody=[ + self.visit(ast.Delete(targets=[ast.Name(id=args_dict_name, ctx=ast.Del())])) + ] + )) + else: + cur_block.append(ast.Try( + body=new_body, + handlers=[], + orelse=[], + finalbody=(afterwards_body if not isinstance(afterwards_body[0], ast.Pass) else []) + [ + self.visit(ast.Delete(targets=[ast.Name(id=args_dict_name, ctx=ast.Del())])) + ] + )) + + return return_node # @magic_contract @@ -297,7 +323,10 @@ def inline(*funs_to_inline, max_depth=1, **kwargs): funs.append((fun_to_inline, fname, fsig, fbody)) + kwargs['function_globals'] = kwargs.get('function_globals', {}) + kwargs['function_globals'].update({_PRAGMA_INLINE_RETURN.__name__: _PRAGMA_INLINE_RETURN}) return make_function_transformer(InlineTransformer, 'inline', 'Inline the specified function within the decorated function', - funs=funs, max_depth=max_depth)(**kwargs) + funs=funs, + max_depth=max_depth)(**kwargs) diff --git a/tests/test_inline.py b/tests/test_inline.py index 6c4ed7c..8a86c4b 100644 --- a/tests/test_inline.py +++ b/tests/test_inline.py @@ -16,32 +16,14 @@ def f(y): result = ''' def f(y): _g_0 = dict(x=y + 3) - for ____ in [None]: - _g_0['return'] = _g_0['x'] ** 2 - break - _g_return_0 = _g_0.get('return', None) - del _g_0 - return _g_return_0 - ''' - - self.assertSourceEqual(f, result) - self.assertEqual(f(1), ((1 + 3) ** 2)) - - def test_basic_unroll(self): - def g(x): - return x**2 - - @pragma.unroll - @pragma.inline(g) - def f(y): - return g(y + 3) - - result = ''' - def f(y): - _g_0 = dict(x=y + 3) - _g_0['return'] = _g_0['x'] ** 2 - _g_return_0 = _g_0.get('return', None) - del _g_0 + try: + raise _PRAGMA_INLINE_RETURN(_g_0['x'] ** 2) + except _PRAGMA_INLINE_RETURN as _g_return_0_exc: + _g_return_0 = _g_return_0_exc.return_val + else: + _g_return_0 = None + finally: + del _g_0 return _g_return_0 ''' @@ -65,27 +47,29 @@ def f(): result1 = ''' def f(): _g_0 = dict(x=1, args=(2, 3, 4), y=5, kwargs={'z': 6, 'w': 7}) - for ____ in [None]: + try: print('X = {}'.format(_g_0['x'])) - for i, a in enumerate(_g_0['args']): - print('args[{}] = {}'.format(i, a)) + for _g_0['i'], _g_0['a'] in enumerate(_g_0['args']): + print('args[{}] = {}'.format(_g_0['i'], _g_0['a'])) print('Y = {}'.format(_g_0['y'])) - for k, v in _g_0['kwargs'].items(): - print('{} = {}'.format(k, v)) - del _g_0 + for _g_0['k'], _g_0['v'] in _g_0['kwargs'].items(): + print('{} = {}'.format(_g_0['k'], _g_0['v'])) + finally: + del _g_0 None ''' result2 = ''' def f(): _g_0 = dict(x=1, args=(2, 3, 4), y=5, kwargs={'w': 7, 'z': 6}) - for ____ in [None]: + try: print('X = {}'.format(_g_0['x'])) - for i, a in enumerate(_g_0['args']): - print('args[{}] = {}'.format(i, a)) + for _g_0['i'], _g_0['a'] in enumerate(_g_0['args']): + print('args[{}] = {}'.format(_g_0['i'], _g_0['a'])) print('Y = {}'.format(_g_0['y'])) - for k, v in _g_0['kwargs'].items(): - print('{} = {}'.format(k, v)) - del _g_0 + for _g_0['k'], _g_0['v'] in _g_0['kwargs'].items(): + print('{} = {}'.format(_g_0['k'], _g_0['v'])) + finally: + del _g_0 None ''' @@ -144,31 +128,19 @@ def f(y): if y <= 0: return 0 _g_0 = dict(x=y - 1) - for ____ in [None]: - _g_0['return'] = f(_g_0['x'] / 2) - break - _g_return_0 = _g_0.get('return', None) - del _g_0 + try: + raise _PRAGMA_INLINE_RETURN(f(_g_0['x'] / 2)) + except _PRAGMA_INLINE_RETURN as _g_return_0_exc: + _g_return_0 = _g_return_0_exc.return_val + else: + _g_return_0 = None + finally: + del _g_0 return _g_return_0 ''' self.assertSourceEqual(f_code, result) - f_unroll_code = pragma.unroll(pragma.inline(g)(f)) - - result_unroll = ''' - def f(y): - if y <= 0: - return 0 - _g_0 = dict(x=y - 1) - _g_0['return'] = f(_g_0['x'] / 2) - _g_return_0 = _g_0.get('return', None) - del _g_0 - return _g_return_0 - ''' - - self.assertSourceEqual(f_unroll_code, result_unroll) - f2_code = pragma.inline(f, g, f=f)(f) result2 = dedent(''' @@ -177,19 +149,24 @@ def f(y): return 0 _g_0 = dict(x=y - 1) _f_0 = dict(y=_g_0['x'] / 2) - for ____ in [None]: + try: if _f_0['y'] <= 0: - _f_0['return'] = 0 - break - _f_0['return'] = g(_f_0['y'] - 1) - break - _f_return_0 = _f_0.get('return', None) - del _f_0 - for ____ in [None]: - _g_0['return'] = _f_return_0 - break - _g_return_0 = _g_0.get('return', None) - del _g_0 + raise _PRAGMA_INLINE_RETURN(0) + raise _PRAGMA_INLINE_RETURN(g(_f_0['y'] - 1)) + except _PRAGMA_INLINE_RETURN as _f_return_0_exc: + _f_return_0 = _f_return_0_exc.return_val + else: + _f_return_0 = None + finally: + del _f_0 + try: + raise _PRAGMA_INLINE_RETURN(_f_return_0) + except _PRAGMA_INLINE_RETURN as _g_return_0_exc: + _g_return_0 = _g_return_0_exc.return_val + else: + _g_return_0 = None + finally: + del _g_0 return _g_return_0 ''') @@ -208,16 +185,18 @@ def f(x): result = ''' def f(x): _g_0 = dict([('yield', [])], y=x) - for ____ in [None]: - for i in range(_g_0['y']): - _g_0['yield'].append(i) + try: + for _g_0['i'] in range(_g_0['y']): + _g_0['yield'].append(_g_0['i']) _g_0['yield'].extend(range(_g_0['y'])) - _g_return_0 = _g_0['yield'] - del _g_0 + finally: + _g_return_0 = _g_0['yield'] + del _g_0 return sum(_g_return_0) ''' self.assertSourceEqual(f, result) + self.assertEqual(f(3), 6) def test_variable_starargs(self): def g(a, b, c): @@ -249,17 +228,28 @@ def f(x): result = ''' def f(x): _a_0 = dict(x=x) - _a_0['return'] = _a_0['x'] ** 2 - _a_return_0 = _a_0.get('return', None) - del _a_0 + try: + raise _PRAGMA_INLINE_RETURN(_a_0['x'] ** 2) + except _PRAGMA_INLINE_RETURN as _a_return_0_exc: + _a_return_0 = _a_return_0_exc.return_val + else: + _a_return_0 = None + finally: + del _a_0 _b_0 = dict(x=x) - _b_0['return'] = _b_0['x'] + 2 - _b_return_0 = _b_0.get('return', None) - del _b_0 + try: + raise _PRAGMA_INLINE_RETURN(_b_0['x'] + 2) + except _PRAGMA_INLINE_RETURN as _b_return_0_exc: + _b_return_0 = _b_return_0_exc.return_val + else: + _b_return_0 = None + finally: + del _b_0 return _a_return_0 + _b_return_0 ''' self.assertSourceEqual(f, result) + self.assertEqual(f(5), 32) def test_coverage(self): def g(y): @@ -290,15 +280,48 @@ def test_my_range(): result = ''' def test_my_range(): _my_range_0 = dict([('yield', [])], x=5) - i = 0 - while i < _my_range_0['x']: - _my_range_0['yield'].append(i) - i += 1 - _my_range_return_0 = _my_range_0['yield'] - del _my_range_0 + try: + _my_range_0['i'] = 0 + while _my_range_0['i'] < _my_range_0['x']: + _my_range_0['yield'].append(_my_range_0['i']) + _my_range_0['i'] += 1 + finally: + _my_range_return_0 = _my_range_0['yield'] + del _my_range_0 return list(_my_range_return_0) ''' self.assertSourceEqual(test_my_range, result) self.assertEqual(test_my_range(), [0, 1, 2, 3, 4]) + def test_return_inside_loop(self): + def g(x): + for i in range(x + 1): + if i == x: + return i + return None + + @pragma.inline(g) + def f(y): + return g(y + 2) + + result = ''' + def f(y): + _g_0 = dict(x=y + 2) + try: + for _g_0['i'] in range(_g_0['x'] + 1): + if _g_0['i'] == _g_0['x']: + raise _PRAGMA_INLINE_RETURN(_g_0['i']) + raise _PRAGMA_INLINE_RETURN(None) + except _PRAGMA_INLINE_RETURN as _g_return_0_exc: + _g_return_0 = _g_return_0_exc.return_val + else: + _g_return_0 = None + finally: + del _g_0 + return _g_return_0 + ''' + + self.assertSourceEqual(f, result) + self.assertEqual(f(3), 5) +