Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
willow-ahrens committed Nov 27, 2024
1 parent 8902a88 commit 0be91a8
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
21 changes: 20 additions & 1 deletion docs/src/docs/array_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,23 @@ julia> @btime compute(sum(A * B * C), ctx=galley_scheduler());

By taking advantage of the fact that C is highly sparse, Galley can better structure the computation. In the matrix chain multiplication,
it always starts with the C,B matmul before multiplying with A. In the summation, it takes advantage of distributivity to pushing the reduction
down to the inputs. It first sums over A and C, then multiplies those vectors with B.
down to the inputs. It first sums over A and C, then multiplies those vectors with B.

Because Galley adapts to the sparsity patterns of the first input tensor, it can
be useful to distinguish between different uses of the same function using the
`tag` keyword argument to `compute` or `fuse`. For example, we may wish to
distinguish one spmv from another, as follows:

```jldoctest example2
julia> A = rand(1000, 1000); B = rand(1000, 1000); C = fsprand(1000, 1000, 0.0001);
julia> fused((A, B, C) -> C .* (A * B), tag=:very_sparse_sddmm);
julia> C = fsprand(1000, 1000, 0.9);
julia> fused((A, B, C) -> C .* (A * B), tag=:very_dense_sddmm);
```

By distinguishing between the two uses of the same function, Galley can make
better decisions about how to optimize each computation separately.
16 changes: 11 additions & 5 deletions src/interface/lazy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,9 +479,12 @@ default_scheduler(;verbose=false) = LogicExecutor(DefaultLogicOptimizer(LogicCom
"""
fused(f, args...; kwargs...)
This function decorator modifies `f` to fuse the contained array
operations and optimize the resulting program. The function must return a single
array or tuple of arrays. `kwargs` are passed to [`compute`](@ref)
This function decorator modifies `f` to fuse the contained array operations and
optimize the resulting program. The function must return a single array or tuple
of arrays. Some keyword arguments can be passed to control the execution of the
program:
- `verbose=false`: Print the generated code before execution
- `tag=:global`: A tag to distinguish between different classes of inputs for the same program.
"""
function fused(f, args...; kwargs...)
compute(f(map(LazyTensor, args...)), kwargs...)
Expand Down Expand Up @@ -520,10 +523,13 @@ function with_scheduler(f, scheduler)
end

"""
compute(args..., ctx=default_scheduler()) -> Any
compute(args...; ctx=default_scheduler(), kwargs...) -> Any
Compute the value of a lazy tensor. The result is the argument itself, or a
tuple of arguments if multiple arguments are passed.
tuple of arguments if multiple arguments are passed. Some keyword arguments
can be passed to control the execution of the program:
- `verbose=false`: Print the generated code before execution
- `tag=:global`: A tag to distinguish between different classes of inputs for the same program.
"""
compute(args...; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), map(lazy, args))
compute(arg; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), (lazy(arg),))[1]
Expand Down
5 changes: 2 additions & 3 deletions src/scheduler/LogicExecutor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ end
codes = Dict()
function (ctx::LogicExecutor)(prgm)
(f, code) = get!(codes, (ctx.ctx, ctx.tag, get_structure(prgm))) do
thunk = logic_executor_code(ctx.ctx, prgm)
(eval(thunk), thunk)
end
thunk = logic_executor_code(ctx.ctx, prgm)
(eval(thunk), thunk)
end
if ctx.verbose
println("Executing:")
Expand Down

0 comments on commit 0be91a8

Please sign in to comment.