diff --git a/src/Finch.jl b/src/Finch.jl index dffe22a88..e6fb58120 100644 --- a/src/Finch.jl +++ b/src/Finch.jl @@ -151,9 +151,9 @@ include("FinchLogic/FinchLogic.jl") using .FinchLogic include("scheduler/LogicCompiler.jl") +include("scheduler/LogicExecutor.jl") include("scheduler/LogicInterpreter.jl") include("scheduler/optimize.jl") -include("scheduler/compute.jl") include("interface/traits.jl") include("interface/abstractarrays.jl") diff --git a/src/interface/lazy.jl b/src/interface/lazy.jl index a8c70b1c3..ffbde5611 100644 --- a/src/interface/lazy.jl +++ b/src/interface/lazy.jl @@ -437,4 +437,24 @@ optimizer to use. """ function fused(f, args...; optimizer=DefaultOptimizer()) compute(f(map(LazyTensor, args...)), optimizer) +end + +default_scheduler = LogicExecutor(DefaultLogicOptimizer(LogicCompiler())) + +""" + compute(args..., ctx=default_scheduler) -> Any + +Compute the value of a lazy tensor. The result is the argument itself, or a +tuple of arguments if multiple arguments are passed. +""" +compute(args...; ctx=default_scheduler) = compute_parse(ctx, args) +compute(arg; ctx=default_scheduler) = compute_parse(ctx, (arg,))[1] +compute(args::Tuple; ctx=default_scheduler) = compute_parse(ctx, args) +function compute_parse(ctx, args::Tuple) + args = collect(args) + vars = map(arg -> alias(gensym(:A)), args) + bodies = map((arg, var) -> query(var, arg.data), args, vars) + prgm = plan(bodies, produces(vars)) + + return ctx(prgm) end \ No newline at end of file diff --git a/src/scheduler/LogicCompiler.jl b/src/scheduler/LogicCompiler.jl index 7675e88dd..372e17dcf 100644 --- a/src/scheduler/LogicCompiler.jl +++ b/src/scheduler/LogicCompiler.jl @@ -37,22 +37,9 @@ function get_structure( end end -""" - FinchCompiler - -The finch compiler is a simple compiler for finch logic programs. The interpreter is -only capable of executing programs of the form: - REORDER := reorder(relabel(ALIAS, FIELD...), FIELD...) - ACCESS := reorder(relabel(ALIAS, idxs_1::FIELD...), idxs_2::FIELD...) where issubsequence(idxs_1, idxs_2) - POINTWISE := ACCESS | mapjoin(IMMEDIATE, POINTWISE...) | reorder(IMMEDIATE, FIELD...) | IMMEDIATE - MAPREDUCE := POINTWISE | aggregate(IMMEDIATE, IMMEDIATE, POINTWISE, FIELD...) - TABLE := table(IMMEDIATE | DEFERRED, FIELD...) -COMPUTE_QUERY := query(ALIAS, reformat(IMMEDIATE, arg::(REORDER | MAPREDUCE))) - INPUT_QUERY := query(ALIAS, TABLE) - STEP := COMPUTE_QUERY | INPUT_QUERY | produces(ALIAS...) - ROOT := PLAN(STEP...) -""" -struct FinchCompiler end +@kwdef struct LogicLowerer + mode = :fast +end function finch_pointwise_logic_to_code(ex) if @capture ex mapjoin(~op, ~args...) @@ -88,7 +75,7 @@ function logic_constant_type(node) end end -function (ctx::FinchCompiler)(ex) +function (ctx::LogicLowerer)(ex) if @capture ex query(~lhs::isalias, table(~tns, ~idxs...)) :($(lhs.name) = $(compile_logic_constant(tns))) elseif @capture ex query(~lhs::isalias, reformat(~tns, reorder(relabel(~arg::isalias, ~idxs_1...), ~idxs_2...))) @@ -103,7 +90,7 @@ function (ctx::FinchCompiler)(ex) end quote $(lhs.name) = $(compile_logic_constant(tns)) - @finch begin + @finch mode = $(QuoteNode(ctx.mode)) begin $(lhs.name) .= $(default(logic_constant_type(tns))) $body return $(lhs.name) @@ -124,7 +111,7 @@ function (ctx::FinchCompiler)(ex) end quote $(lhs.name) = $(compile_logic_constant(tns)) - @finch begin + @finch mode = $(QuoteNode(ctx.mode)) begin $(lhs.name) .= $(default(logic_constant_type(tns))) $body return $(lhs.name) @@ -139,61 +126,33 @@ function (ctx::FinchCompiler)(ex) end end -""" - defer_tables(root::LogicNode) - -Replace immediate tensors with deferred expressions assuming the original program structure -is given as input to the program. -""" -function defer_tables(ex, node::LogicNode) - if @capture node table(~tns::isimmediate, ~idxs...) - table(deferred(:($ex.tns.val), typeof(tns.val)), map(enumerate(node.idxs)) do (i, idx) - defer_tables(:($ex.idxs[$i]), idx) - end) - elseif istree(node) - similarterm(node, operation(node), map(enumerate(node.children)) do (i, child) - defer_tables(:($ex.children[$i]), child) - end) - else - node - end -end """ - cache_deferred(ctx, root::LogicNode, seen) + LogicCompiler -Replace deferred expressions with simpler expressions, and cache their evaluation in the preamble. +The LogicCompiler is a simple compiler for finch logic programs. The interpreter is +only capable of executing programs of the form: + REORDER := reorder(relabel(ALIAS, FIELD...), FIELD...) + ACCESS := reorder(relabel(ALIAS, idxs_1::FIELD...), idxs_2::FIELD...) where issubsequence(idxs_1, idxs_2) + POINTWISE := ACCESS | mapjoin(IMMEDIATE, POINTWISE...) | reorder(IMMEDIATE, FIELD...) | IMMEDIATE + MAPREDUCE := POINTWISE | aggregate(IMMEDIATE, IMMEDIATE, POINTWISE, FIELD...) + TABLE := table(IMMEDIATE | DEFERRED, FIELD...) +COMPUTE_QUERY := query(ALIAS, reformat(IMMEDIATE, arg::(REORDER | MAPREDUCE))) + INPUT_QUERY := query(ALIAS, TABLE) + STEP := COMPUTE_QUERY | INPUT_QUERY | produces(ALIAS...) + ROOT := PLAN(STEP...) """ -function cache_deferred!(ctx, root::LogicNode) - seen::Dict{Any, LogicNode} = Dict{Any, LogicNode}() - return Rewrite(Postwalk(node -> if isdeferred(node) - get!(seen, node.val) do - var = freshen(ctx, :V) - push!(ctx.preamble, :($var = $(node.ex)::$(node.type))) - deferred(var, node.type) - end - end))(root) +@kwdef struct LogicCompiler + mode = :fast end -function compile(prgm::LogicNode) - ctx = JuliaContext() - freshen(ctx, :prgm) - code = contain(ctx) do ctx_2 - prgm = defer_tables(:prgm, prgm) - prgm = cache_deferred!(ctx_2, prgm) - prgm = optimize(prgm) - prgm = format_queries(prgm, true) - FinchCompiler()(prgm) - end - code = pretty(code) - fname = gensym(:compute) - return :(function $fname(prgm) - $code - end) |> striplines +function (ctx::LogicCompiler)(prgm::LogicNode) + prgm = format_queries(prgm, true) + LogicLowerer(mode=ctx.mode)(prgm) end codes = Dict() -function compute_impl(prgm, ::FinchCompiler) +function compute_impl(prgm, ::LogicLowerer) f = get!(codes, get_structure(prgm)) do eval(compile(prgm)) end diff --git a/src/scheduler/LogicExecutor.jl b/src/scheduler/LogicExecutor.jl new file mode 100644 index 000000000..781b77b9c --- /dev/null +++ b/src/scheduler/LogicExecutor.jl @@ -0,0 +1,82 @@ +""" + defer_tables(root::LogicNode) + +Replace immediate tensors with deferred expressions assuming the original program structure +is given as input to the program. +""" +function defer_tables(ex, node::LogicNode) + if @capture node table(~tns::isimmediate, ~idxs...) + table(deferred(:($ex.tns.val), typeof(tns.val)), map(enumerate(node.idxs)) do (i, idx) + defer_tables(:($ex.idxs[$i]), idx) + end) + elseif istree(node) + similarterm(node, operation(node), map(enumerate(node.children)) do (i, child) + defer_tables(:($ex.children[$i]), child) + end) + else + node + end +end + +""" + cache_deferred(ctx, root::LogicNode, seen) + +Replace deferred expressions with simpler expressions, and cache their evaluation in the preamble. +""" +function cache_deferred!(ctx, root::LogicNode) + seen::Dict{Any, LogicNode} = Dict{Any, LogicNode}() + return Rewrite(Postwalk(node -> if isdeferred(node) + get!(seen, node.val) do + var = freshen(ctx, :V) + push!(ctx.preamble, :($var = $(node.ex)::$(node.type))) + deferred(var, node.type) + end + end))(root) +end + +function logic_executor_code(ctx, prgm) + ctx_2 = JuliaContext() + freshen(ctx_2, :prgm) + code = contain(ctx_2) do ctx_3 + prgm = defer_tables(:prgm, prgm) + prgm = cache_deferred!(ctx_3, prgm) + ctx(prgm) + end + code = pretty(code) + fname = gensym(:compute) + return :(function $fname(prgm) + $code + end) |> striplines +end + +""" + LogicExecutor(ctx) + +Executes a logic program by compiling it with the given compiler `ctx`. Compiled +codes are cached, and are only compiled once for each program with the same +structure. +""" +struct LogicExecutor + ctx +end + +codes = Dict() +function (ctx::LogicExecutor)(prgm) + f = get!(codes, get_structure(prgm)) do + eval(logic_executor_code(ctx.ctx, prgm)) + end + return Base.invokelatest(f, prgm) +end + +""" + LogicExecutorCode(ctx) + +Return the code that would normally be used by the LogicExecutor to run a program. +""" +struct LogicExecutorCode + ctx +end + +function (ctx::LogicExecutorCode)(prgm) + return logic_executor_code(ctx.ctx, prgm) +end \ No newline at end of file diff --git a/src/scheduler/LogicInterpreter.jl b/src/scheduler/LogicInterpreter.jl index 14f1c182e..29c4cdc25 100644 --- a/src/scheduler/LogicInterpreter.jl +++ b/src/scheduler/LogicInterpreter.jl @@ -1,24 +1,3 @@ -""" - FinchInterpreter - -The finch interpreter is a simple interpreter for finch logic programs. The interpreter is -only capable of executing programs of the form: - REORDER := reorder(relabel(ALIAS, FIELD...), FIELD...) - ACCESS := reorder(relabel(ALIAS, idxs_1::FIELD...), idxs_2::FIELD...) where issubsequence(idxs_1, idxs_2) - POINTWISE := ACCESS | mapjoin(IMMEDIATE, POINTWISE...) | reorder(IMMEDIATE, FIELD...) | IMMEDIATE - MAPREDUCE := POINTWISE | aggregate(IMMEDIATE, IMMEDIATE, POINTWISE, FIELD...) - TABLE := table(IMMEDIATE, FIELD...) -COMPUTE_QUERY := query(ALIAS, reformat(IMMEDIATE, arg::(REORDER | MAPREDUCE))) - INPUT_QUERY := query(ALIAS, TABLE) - STEP := COMPUTE_QUERY | INPUT_QUERY | produces(ALIAS...) - ROOT := PLAN(STEP...) -""" -struct FinchInterpreter - scope::Dict -end - -FinchInterpreter() = FinchInterpreter(Dict()) - using Finch.FinchNotation: block_instance, declare_instance, call_instance, loop_instance, index_instance, variable_instance, tag_instance, access_instance, assign_instance, literal_instance, yieldbind_instance function finch_pointwise_logic_to_program(scope, ex) @@ -38,7 +17,13 @@ function finch_pointwise_logic_to_program(scope, ex) end end -function (ctx::FinchInterpreter)(ex) +@kwdef struct LogicMachine + scope = Dict{Any, Any}() + verbose = false + mode = :fast +end + +function (ctx::LogicMachine)(ex) if ex.kind === alias ex.scope[ex] elseif @capture ex query(~lhs, ~rhs) @@ -57,8 +42,11 @@ function (ctx::FinchInterpreter)(ex) body = loop_instance(idx, dimless, body) end body = block_instance(declare_instance(res, literal_instance(default(tns.val))), body, yieldbind_instance(res)) - #display(body) # wow it's really satisfying to uncomment this and type finch ops at the repl. - execute(body).res + if ctx.verbose + print("Running: ") + display(body) + end + execute(body, mode = ctx.mode).res elseif @capture ex reformat(~tns, mapjoin(~args...)) z = default(tns.val) ctx(reformat(tns, aggregate(initwrite(z), immediate(z), mapjoin(args...)))) @@ -73,8 +61,11 @@ function (ctx::FinchInterpreter)(ex) body = loop_instance(idx, dimless, body) end body = block_instance(declare_instance(res, literal_instance(default(tns.val))), body, yieldbind_instance(res)) - #display(body) # wow it's really satisfying to uncomment this and type finch ops at the repl. - execute(body).res + if ctx.verbose + print("Running: ") + display(body) + end + execute(body, mode = ctx.mode).res elseif @capture ex produces(~args...) return map(arg -> ctx.scope[arg], args) elseif @capture ex plan(~head) @@ -87,14 +78,27 @@ function (ctx::FinchInterpreter)(ex) end end -function normalize_names(ex) - spc = Namespace() - scope = Dict() - normname(sym) = get!(scope, sym) do - if isgensym(sym) - sym = gensymname(sym) - end - freshen(spc, sym) - end - Rewrite(Postwalk(@rule ~a::isalias => alias(normname(a.name))))(ex) +""" + LogicInterpreter(scope = Dict(), verbose = false, mode = :fast) + +The LogicInterpreter is a simple interpreter for finch logic programs. The interpreter is +only capable of executing programs of the form: + REORDER := reorder(relabel(ALIAS, FIELD...), FIELD...) + ACCESS := reorder(relabel(ALIAS, idxs_1::FIELD...), idxs_2::FIELD...) where issubsequence(idxs_1, idxs_2) + POINTWISE := ACCESS | mapjoin(IMMEDIATE, POINTWISE...) | reorder(IMMEDIATE, FIELD...) | IMMEDIATE + MAPREDUCE := POINTWISE | aggregate(IMMEDIATE, IMMEDIATE, POINTWISE, FIELD...) + TABLE := table(IMMEDIATE, FIELD...) +COMPUTE_QUERY := query(ALIAS, reformat(IMMEDIATE, arg::(REORDER | MAPREDUCE))) + INPUT_QUERY := query(ALIAS, TABLE) + STEP := COMPUTE_QUERY | INPUT_QUERY | produces(ALIAS...) + ROOT := PLAN(STEP...) +""" +@kwdef struct LogicInterpreter + verbose = false + mode = :fast +end + +function (ctx::LogicInterpreter)(prgm) + prgm = format_queries(prgm) + LogicMachine(verbose = ctx.verbose, mode = ctx.mode)(prgm) end \ No newline at end of file diff --git a/src/scheduler/compute.jl b/src/scheduler/compute.jl deleted file mode 100644 index 9ff6d9c5a..000000000 --- a/src/scheduler/compute.jl +++ /dev/null @@ -1,33 +0,0 @@ -struct DefaultOptimizer - ctx -end - -default_optimizer = DefaultOptimizer(FinchCompiler()) - -""" - compute(args..., ctx=default_optimizer) -> Any - -Compute the value of a lazy tensor. The result is the argument itself, or a -tuple of arguments if multiple arguments are passed. -""" -compute(args...; ctx=default_optimizer) = compute_parse(args, ctx) -compute(arg; ctx=default_optimizer) = compute_parse((arg,), ctx)[1] -compute(args::Tuple; ctx=default_optimizer) = compute_parse(args, ctx) -function compute_parse(args::Tuple, ctx) - args = collect(args) - vars = map(arg -> alias(gensym(:A)), args) - bodies = map((arg, var) -> query(var, arg.data), args, vars) - prgm = plan(bodies, produces(vars)) - - return compute_impl(prgm, ctx) -end - -function compute_impl(prgm, ctx::DefaultOptimizer) - compute_impl(prgm, ctx.ctx) -end - -function compute_impl(prgm, ctx::FinchInterpreter) - prgm = optimize(prgm) - prgm = format_queries(prgm) - ctx(prgm) -end diff --git a/src/scheduler/optimize.jl b/src/scheduler/optimize.jl index 7e5ea6c6f..a5d37aea1 100644 --- a/src/scheduler/optimize.jl +++ b/src/scheduler/optimize.jl @@ -378,6 +378,18 @@ function propagate_map_queries(root) ])))(root) end +function normalize_names(ex) + spc = Namespace() + scope = Dict() + normname(sym) = get!(scope, sym) do + if isgensym(sym) + sym = gensymname(sym) + end + freshen(spc, sym) + end + Rewrite(Postwalk(@rule ~a::isalias => alias(normname(a.name))))(ex) +end + function optimize(prgm) #deduplicate and lift inline subqueries to regular queries prgm = lift_subqueries(prgm) @@ -395,9 +407,6 @@ function optimize(prgm) #I shouldn't use gensym but I do, so this cleans up the names prgm = pretty_labels(prgm) - #@info "split" - #display(prgm) - #These steps fuse copy, permutation, and mapjoin statements #into later expressions. #Only reformat statements preserve intermediate breaks in computation @@ -405,35 +414,36 @@ function optimize(prgm) prgm = propagate_transpose_queries(prgm) prgm = propagate_map_queries(prgm) - #@info "fused" - #display(prgm) - #These steps assign a global loop order to each statement. prgm = propagate_fields(prgm) - #@info "propagate_fields" - #display(prgm) - prgm = push_fields(prgm) prgm = lift_fields(prgm) prgm = push_fields(prgm) - #@info "loops ordered" - #display(prgm) - #After we have a global loop order, we concordize the program prgm = concordize(prgm) - #@info "concordized" - #display(prgm) - #Add reformat statements where there aren't any prgm = propagate_into_reformats(prgm) prgm = propagate_copy_queries(prgm) + #Normalize names for caching prgm = normalize_names(prgm) +end - #@info "formatted" - #display(prgm) +""" + DefaultLogicOptimizer(ctx) + +The default optimizer for finch logic programs. Optimizes to a structure +suitable for the LogicCompiler or LogicInterpreter, then calls `ctx` on the +resulting program. +""" +struct DefaultLogicOptimizer + ctx +end +function (ctx::DefaultLogicOptimizer)(prgm) + prgm = optimize(prgm) + ctx.ctx(prgm) end \ No newline at end of file