From 0be91a8567e632a23b4a6bb61aef3bc6042b3ae0 Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Tue, 26 Nov 2024 20:45:10 -0500 Subject: [PATCH] fix --- docs/src/docs/array_api.md | 21 ++++++++++++++++++++- src/interface/lazy.jl | 16 +++++++++++----- src/scheduler/LogicExecutor.jl | 5 ++--- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/docs/src/docs/array_api.md b/docs/src/docs/array_api.md index 06b2091a0..bcd66f9e2 100644 --- a/docs/src/docs/array_api.md +++ b/docs/src/docs/array_api.md @@ -144,4 +144,23 @@ julia> @btime compute(sum(A * B * C), ctx=galley_scheduler()); By taking advantage of the fact that C is highly sparse, Galley can better structure the computation. In the matrix chain multiplication, it always starts with the C,B matmul before multiplying with A. In the summation, it takes advantage of distributivity to pushing the reduction -down to the inputs. It first sums over A and C, then multiplies those vectors with B. \ No newline at end of file +down to the inputs. It first sums over A and C, then multiplies those vectors with B. + +Because Galley adapts to the sparsity patterns of the first input tensor, it can +be useful to distinguish between different uses of the same function using the +`tag` keyword argument to `compute` or `fuse`. For example, we may wish to +distinguish one spmv from another, as follows: + +```jldoctest example2 +julia> A = rand(1000, 1000); B = rand(1000, 1000); C = fsprand(1000, 1000, 0.0001); + +julia> fused((A, B, C) -> C .* (A * B), tag=:very_sparse_sddmm); + +julia> C = fsprand(1000, 1000, 0.9); + +julia> fused((A, B, C) -> C .* (A * B), tag=:very_dense_sddmm); + +``` + +By distinguishing between the two uses of the same function, Galley can make +better decisions about how to optimize each computation separately. \ No newline at end of file diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index 578dc4a9e..7b9f051bf 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -479,9 +479,12 @@ default_scheduler(;verbose=false) = LogicExecutor(DefaultLogicOptimizer(LogicCom """ fused(f, args...; kwargs...) -This function decorator modifies `f` to fuse the contained array -operations and optimize the resulting program. The function must return a single -array or tuple of arrays. `kwargs` are passed to [`compute`](@ref) +This function decorator modifies `f` to fuse the contained array operations and +optimize the resulting program. The function must return a single array or tuple +of arrays. Some keyword arguments can be passed to control the execution of the +program: + - `verbose=false`: Print the generated code before execution + - `tag=:global`: A tag to distinguish between different classes of inputs for the same program. """ function fused(f, args...; kwargs...) compute(f(map(LazyTensor, args...)), kwargs...) @@ -520,10 +523,13 @@ function with_scheduler(f, scheduler) end """ - compute(args..., ctx=default_scheduler()) -> Any + compute(args...; ctx=default_scheduler(), kwargs...) -> Any Compute the value of a lazy tensor. The result is the argument itself, or a -tuple of arguments if multiple arguments are passed. +tuple of arguments if multiple arguments are passed. Some keyword arguments +can be passed to control the execution of the program: + - `verbose=false`: Print the generated code before execution + - `tag=:global`: A tag to distinguish between different classes of inputs for the same program. """ compute(args...; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), map(lazy, args)) compute(arg; ctx=get_scheduler(), kwargs...) = compute_parse(set_options(ctx; kwargs...), (lazy(arg),))[1] diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl index 22e1a3b2e..89a1d26a9 100644 --- a/src/scheduler/LogicExecutor.jl +++ b/src/scheduler/LogicExecutor.jl @@ -74,9 +74,8 @@ end codes = Dict() function (ctx::LogicExecutor)(prgm) (f, code) = get!(codes, (ctx.ctx, ctx.tag, get_structure(prgm))) do - thunk = logic_executor_code(ctx.ctx, prgm) - (eval(thunk), thunk) - end + thunk = logic_executor_code(ctx.ctx, prgm) + (eval(thunk), thunk) end if ctx.verbose println("Executing:")