Skip to content

Commit

Permalink
Place new rewrite behind opt-in kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Nov 14, 2022
1 parent 1b399fe commit 0cdedc7
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 137 deletions.
4 changes: 0 additions & 4 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,3 @@ end
```@autodocs
Modules = [MutableArithmetics]
```

```@autodocs
Modules = [MutableArithmetics.MutableArithmetics2]
```
2 changes: 1 addition & 1 deletion src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ function isequal_canonical(x::_SparseMat, y::_SparseMat)
end

include("rewrite.jl")
include("rewrite_generic.jl")
include("dispatch.jl")
include("new_rewrite.jl")

# Test that can be used to test an implementation of the interface
include("Test/Test.jl")
Expand Down
54 changes: 35 additions & 19 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
# one at http://mozilla.org/MPL/2.0/.

"""
@rewrite(expr)
@rewrite(expr; assume_sums_are_linear = false)
Return the value of `expr` exploiting the mutability of the temporary
Return the value of `expr`, exploiting the mutability of the temporary
expressions created for the computation of the result.
## Examples
Expand All @@ -21,12 +21,23 @@ is rewritten into
MA.add_mul!!(
MA.add_mul!!(
MA.copy_if_mutable(x),
y, z),
u, v, w)
y,
z,
),
u,
v,
w,
)
```
"""
macro rewrite(expr)
return rewrite_and_return(expr)
macro rewrite(args...)
@assert 1 <= length(args) <= 2
if length(args) == 1
return rewrite_and_return(args[1]; assume_sums_are_linear = true)
end
@assert Meta.isexpr(args[2], :(=), 2) &&
args[2].args[1] == :assume_sums_are_linear
return rewrite_and_return(args[1]; assume_sums_are_linear = args[2].args[2])
end

struct Zero end
Expand Down Expand Up @@ -268,31 +279,36 @@ function _is_decomposable_with_factors(ex)
end

"""
rewrite(x)
rewrite(expr; assume_sums_are_linear::Bool = true) -> Tuple{Symbol,Expr}
Rewrites the expression `expr` to use mutable arithmetics.
Rewrite the expression `x` as specified in [`@rewrite`](@ref).
Returns a variable name as `Symbol` and the rewritten expression assigning the
value of the expression `x` to the variable.
Returns `(variable, code)` comprised of a `gensym`'d variable equivalent to
`expr` and the code necessary to create the variable.
"""
function rewrite(x)
function rewrite(x; kwargs...)
variable = gensym()
code = rewrite_and_return(x)
code = rewrite_and_return(x; kwargs...)
return variable, :($variable = $code)
end

"""
rewrite_and_return(x)
rewrite_and_return(expr; assume_sums_are_linear::Bool = true) -> Expr
Rewrite the expression `x` as specified in [`@rewrite`](@ref).
Rewrite the expression `expr` as specified in [`@rewrite`](@ref).
Return the rewritten expression returning the result.
"""
function rewrite_and_return(x)
output_variable, code = _rewrite(false, false, x, nothing, [], [])
# We need to use `let` because `rewrite(:(sum(i for i in 1:2))`
function rewrite_and_return(expr; assume_sums_are_linear::Bool = true)
if assume_sums_are_linear
root, stack = _rewrite(false, false, expr, nothing, [], [])
else
stack = quote end
root, _ = _rewrite_generic(stack, expr)
end
return quote
let
$code
$output_variable
$stack
$root
end
end
end
Expand Down
Loading

0 comments on commit 0cdedc7

Please sign in to comment.