Skip to content

Commit

Permalink
Merge pull request #522 from willow-ahrens/wma/cleanup-caching2
Browse files Browse the repository at this point in the history
Wma/cleanup caching2
  • Loading branch information
willow-ahrens authored May 5, 2024
2 parents 9c110ec + d221623 commit c7cd3e9
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 152 deletions.
2 changes: 1 addition & 1 deletion src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ include("FinchLogic/FinchLogic.jl")
using .FinchLogic

include("scheduler/LogicCompiler.jl")
include("scheduler/LogicExecutor.jl")
include("scheduler/LogicInterpreter.jl")
include("scheduler/optimize.jl")
include("scheduler/compute.jl")

include("interface/traits.jl")
include("interface/abstractarrays.jl")
Expand Down
20 changes: 20 additions & 0 deletions src/interface/lazy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -437,4 +437,24 @@ optimizer to use.
"""
function fused(f, args...; optimizer=DefaultOptimizer())
compute(f(map(LazyTensor, args...)), optimizer)
end

default_scheduler = LogicExecutor(DefaultLogicOptimizer(LogicCompiler()))

"""
compute(args..., ctx=default_scheduler) -> Any
Compute the value of a lazy tensor. The result is the argument itself, or a
tuple of arguments if multiple arguments are passed.
"""
compute(args...; ctx=default_scheduler) = compute_parse(ctx, args)
compute(arg; ctx=default_scheduler) = compute_parse(ctx, (arg,))[1]
compute(args::Tuple; ctx=default_scheduler) = compute_parse(ctx, args)
function compute_parse(ctx, args::Tuple)
args = collect(args)
vars = map(arg -> alias(gensym(:A)), args)
bodies = map((arg, var) -> query(var, arg.data), args, vars)
prgm = plan(bodies, produces(vars))

return ctx(prgm)
end
89 changes: 24 additions & 65 deletions src/scheduler/LogicCompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,9 @@ function get_structure(
end
end

"""
FinchCompiler
The finch compiler is a simple compiler for finch logic programs. The interpreter is
only capable of executing programs of the form:
REORDER := reorder(relabel(ALIAS, FIELD...), FIELD...)
ACCESS := reorder(relabel(ALIAS, idxs_1::FIELD...), idxs_2::FIELD...) where issubsequence(idxs_1, idxs_2)
POINTWISE := ACCESS | mapjoin(IMMEDIATE, POINTWISE...) | reorder(IMMEDIATE, FIELD...) | IMMEDIATE
MAPREDUCE := POINTWISE | aggregate(IMMEDIATE, IMMEDIATE, POINTWISE, FIELD...)
TABLE := table(IMMEDIATE | DEFERRED, FIELD...)
COMPUTE_QUERY := query(ALIAS, reformat(IMMEDIATE, arg::(REORDER | MAPREDUCE)))
INPUT_QUERY := query(ALIAS, TABLE)
STEP := COMPUTE_QUERY | INPUT_QUERY | produces(ALIAS...)
ROOT := PLAN(STEP...)
"""
struct FinchCompiler end
@kwdef struct LogicLowerer
mode = :fast
end

function finch_pointwise_logic_to_code(ex)
if @capture ex mapjoin(~op, ~args...)
Expand Down Expand Up @@ -88,7 +75,7 @@ function logic_constant_type(node)
end
end

function (ctx::FinchCompiler)(ex)
function (ctx::LogicLowerer)(ex)
if @capture ex query(~lhs::isalias, table(~tns, ~idxs...))
:($(lhs.name) = $(compile_logic_constant(tns)))
elseif @capture ex query(~lhs::isalias, reformat(~tns, reorder(relabel(~arg::isalias, ~idxs_1...), ~idxs_2...)))
Expand All @@ -103,7 +90,7 @@ function (ctx::FinchCompiler)(ex)
end
quote
$(lhs.name) = $(compile_logic_constant(tns))
@finch begin
@finch mode = $(QuoteNode(ctx.mode)) begin
$(lhs.name) .= $(default(logic_constant_type(tns)))
$body
return $(lhs.name)
Expand All @@ -124,7 +111,7 @@ function (ctx::FinchCompiler)(ex)
end
quote
$(lhs.name) = $(compile_logic_constant(tns))
@finch begin
@finch mode = $(QuoteNode(ctx.mode)) begin
$(lhs.name) .= $(default(logic_constant_type(tns)))
$body
return $(lhs.name)
Expand All @@ -139,61 +126,33 @@ function (ctx::FinchCompiler)(ex)
end
end

"""
defer_tables(root::LogicNode)
Replace immediate tensors with deferred expressions assuming the original program structure
is given as input to the program.
"""
function defer_tables(ex, node::LogicNode)
if @capture node table(~tns::isimmediate, ~idxs...)
table(deferred(:($ex.tns.val), typeof(tns.val)), map(enumerate(node.idxs)) do (i, idx)
defer_tables(:($ex.idxs[$i]), idx)
end)
elseif istree(node)
similarterm(node, operation(node), map(enumerate(node.children)) do (i, child)
defer_tables(:($ex.children[$i]), child)
end)
else
node
end
end

"""
cache_deferred(ctx, root::LogicNode, seen)
LogicCompiler
Replace deferred expressions with simpler expressions, and cache their evaluation in the preamble.
The LogicCompiler is a simple compiler for finch logic programs. The interpreter is
only capable of executing programs of the form:
REORDER := reorder(relabel(ALIAS, FIELD...), FIELD...)
ACCESS := reorder(relabel(ALIAS, idxs_1::FIELD...), idxs_2::FIELD...) where issubsequence(idxs_1, idxs_2)
POINTWISE := ACCESS | mapjoin(IMMEDIATE, POINTWISE...) | reorder(IMMEDIATE, FIELD...) | IMMEDIATE
MAPREDUCE := POINTWISE | aggregate(IMMEDIATE, IMMEDIATE, POINTWISE, FIELD...)
TABLE := table(IMMEDIATE | DEFERRED, FIELD...)
COMPUTE_QUERY := query(ALIAS, reformat(IMMEDIATE, arg::(REORDER | MAPREDUCE)))
INPUT_QUERY := query(ALIAS, TABLE)
STEP := COMPUTE_QUERY | INPUT_QUERY | produces(ALIAS...)
ROOT := PLAN(STEP...)
"""
function cache_deferred!(ctx, root::LogicNode)
seen::Dict{Any, LogicNode} = Dict{Any, LogicNode}()
return Rewrite(Postwalk(node -> if isdeferred(node)
get!(seen, node.val) do
var = freshen(ctx, :V)
push!(ctx.preamble, :($var = $(node.ex)::$(node.type)))
deferred(var, node.type)
end
end))(root)
@kwdef struct LogicCompiler
mode = :fast
end

function compile(prgm::LogicNode)
ctx = JuliaContext()
freshen(ctx, :prgm)
code = contain(ctx) do ctx_2
prgm = defer_tables(:prgm, prgm)
prgm = cache_deferred!(ctx_2, prgm)
prgm = optimize(prgm)
prgm = format_queries(prgm, true)
FinchCompiler()(prgm)
end
code = pretty(code)
fname = gensym(:compute)
return :(function $fname(prgm)
$code
end) |> striplines
function (ctx::LogicCompiler)(prgm::LogicNode)
prgm = format_queries(prgm, true)
LogicLowerer(mode=ctx.mode)(prgm)
end

codes = Dict()
function compute_impl(prgm, ::FinchCompiler)
function compute_impl(prgm, ::LogicLowerer)
f = get!(codes, get_structure(prgm)) do
eval(compile(prgm))
end
Expand Down
82 changes: 82 additions & 0 deletions src/scheduler/LogicExecutor.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
defer_tables(root::LogicNode)
Replace immediate tensors with deferred expressions assuming the original program structure
is given as input to the program.
"""
function defer_tables(ex, node::LogicNode)
if @capture node table(~tns::isimmediate, ~idxs...)
table(deferred(:($ex.tns.val), typeof(tns.val)), map(enumerate(node.idxs)) do (i, idx)
defer_tables(:($ex.idxs[$i]), idx)
end)
elseif istree(node)
similarterm(node, operation(node), map(enumerate(node.children)) do (i, child)
defer_tables(:($ex.children[$i]), child)
end)
else
node
end
end

"""
cache_deferred(ctx, root::LogicNode, seen)
Replace deferred expressions with simpler expressions, and cache their evaluation in the preamble.
"""
function cache_deferred!(ctx, root::LogicNode)
seen::Dict{Any, LogicNode} = Dict{Any, LogicNode}()
return Rewrite(Postwalk(node -> if isdeferred(node)
get!(seen, node.val) do
var = freshen(ctx, :V)
push!(ctx.preamble, :($var = $(node.ex)::$(node.type)))
deferred(var, node.type)
end
end))(root)
end

function logic_executor_code(ctx, prgm)
ctx_2 = JuliaContext()
freshen(ctx_2, :prgm)
code = contain(ctx_2) do ctx_3
prgm = defer_tables(:prgm, prgm)
prgm = cache_deferred!(ctx_3, prgm)
ctx(prgm)
end
code = pretty(code)
fname = gensym(:compute)
return :(function $fname(prgm)
$code
end) |> striplines
end

"""
LogicExecutor(ctx)
Executes a logic program by compiling it with the given compiler `ctx`. Compiled
codes are cached, and are only compiled once for each program with the same
structure.
"""
struct LogicExecutor
ctx
end

codes = Dict()
function (ctx::LogicExecutor)(prgm)
f = get!(codes, get_structure(prgm)) do
eval(logic_executor_code(ctx.ctx, prgm))
end
return Base.invokelatest(f, prgm)
end

"""
LogicExecutorCode(ctx)
Return the code that would normally be used by the LogicExecutor to run a program.
"""
struct LogicExecutorCode
ctx
end

function (ctx::LogicExecutorCode)(prgm)
return logic_executor_code(ctx.ctx, prgm)
end
76 changes: 40 additions & 36 deletions src/scheduler/LogicInterpreter.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,3 @@
"""
FinchInterpreter
The finch interpreter is a simple interpreter for finch logic programs. The interpreter is
only capable of executing programs of the form:
REORDER := reorder(relabel(ALIAS, FIELD...), FIELD...)
ACCESS := reorder(relabel(ALIAS, idxs_1::FIELD...), idxs_2::FIELD...) where issubsequence(idxs_1, idxs_2)
POINTWISE := ACCESS | mapjoin(IMMEDIATE, POINTWISE...) | reorder(IMMEDIATE, FIELD...) | IMMEDIATE
MAPREDUCE := POINTWISE | aggregate(IMMEDIATE, IMMEDIATE, POINTWISE, FIELD...)
TABLE := table(IMMEDIATE, FIELD...)
COMPUTE_QUERY := query(ALIAS, reformat(IMMEDIATE, arg::(REORDER | MAPREDUCE)))
INPUT_QUERY := query(ALIAS, TABLE)
STEP := COMPUTE_QUERY | INPUT_QUERY | produces(ALIAS...)
ROOT := PLAN(STEP...)
"""
struct FinchInterpreter
scope::Dict
end

FinchInterpreter() = FinchInterpreter(Dict())

using Finch.FinchNotation: block_instance, declare_instance, call_instance, loop_instance, index_instance, variable_instance, tag_instance, access_instance, assign_instance, literal_instance, yieldbind_instance

function finch_pointwise_logic_to_program(scope, ex)
Expand All @@ -38,7 +17,13 @@ function finch_pointwise_logic_to_program(scope, ex)
end
end

function (ctx::FinchInterpreter)(ex)
@kwdef struct LogicMachine
scope = Dict{Any, Any}()
verbose = false
mode = :fast
end

function (ctx::LogicMachine)(ex)
if ex.kind === alias
ex.scope[ex]
elseif @capture ex query(~lhs, ~rhs)
Expand All @@ -57,8 +42,11 @@ function (ctx::FinchInterpreter)(ex)
body = loop_instance(idx, dimless, body)
end
body = block_instance(declare_instance(res, literal_instance(default(tns.val))), body, yieldbind_instance(res))
#display(body) # wow it's really satisfying to uncomment this and type finch ops at the repl.
execute(body).res
if ctx.verbose
print("Running: ")
display(body)
end
execute(body, mode = ctx.mode).res
elseif @capture ex reformat(~tns, mapjoin(~args...))
z = default(tns.val)
ctx(reformat(tns, aggregate(initwrite(z), immediate(z), mapjoin(args...))))
Expand All @@ -73,8 +61,11 @@ function (ctx::FinchInterpreter)(ex)
body = loop_instance(idx, dimless, body)
end
body = block_instance(declare_instance(res, literal_instance(default(tns.val))), body, yieldbind_instance(res))
#display(body) # wow it's really satisfying to uncomment this and type finch ops at the repl.
execute(body).res
if ctx.verbose
print("Running: ")
display(body)
end
execute(body, mode = ctx.mode).res
elseif @capture ex produces(~args...)
return map(arg -> ctx.scope[arg], args)
elseif @capture ex plan(~head)
Expand All @@ -87,14 +78,27 @@ function (ctx::FinchInterpreter)(ex)
end
end

function normalize_names(ex)
spc = Namespace()
scope = Dict()
normname(sym) = get!(scope, sym) do
if isgensym(sym)
sym = gensymname(sym)
end
freshen(spc, sym)
end
Rewrite(Postwalk(@rule ~a::isalias => alias(normname(a.name))))(ex)
"""
LogicInterpreter(scope = Dict(), verbose = false, mode = :fast)
The LogicInterpreter is a simple interpreter for finch logic programs. The interpreter is
only capable of executing programs of the form:
REORDER := reorder(relabel(ALIAS, FIELD...), FIELD...)
ACCESS := reorder(relabel(ALIAS, idxs_1::FIELD...), idxs_2::FIELD...) where issubsequence(idxs_1, idxs_2)
POINTWISE := ACCESS | mapjoin(IMMEDIATE, POINTWISE...) | reorder(IMMEDIATE, FIELD...) | IMMEDIATE
MAPREDUCE := POINTWISE | aggregate(IMMEDIATE, IMMEDIATE, POINTWISE, FIELD...)
TABLE := table(IMMEDIATE, FIELD...)
COMPUTE_QUERY := query(ALIAS, reformat(IMMEDIATE, arg::(REORDER | MAPREDUCE)))
INPUT_QUERY := query(ALIAS, TABLE)
STEP := COMPUTE_QUERY | INPUT_QUERY | produces(ALIAS...)
ROOT := PLAN(STEP...)
"""
@kwdef struct LogicInterpreter
verbose = false
mode = :fast
end

function (ctx::LogicInterpreter)(prgm)
prgm = format_queries(prgm)
LogicMachine(verbose = ctx.verbose, mode = ctx.mode)(prgm)
end
33 changes: 0 additions & 33 deletions src/scheduler/compute.jl

This file was deleted.

Loading

0 comments on commit c7cd3e9

Please sign in to comment.