-
Notifications
You must be signed in to change notification settings - Fork 107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Connecting to ChainRulesCore for JuliaLang AD compat #652
base: main
Are you sure you want to change the base?
Changes from all commits
c5104b4
42cbad5
37372cb
e6c5759
b2f188c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,16 @@ | ||
"Calling Dex from Julia" | ||
module DexCall | ||
using ChainRulesCore | ||
using CombinedParsers | ||
using CombinedParsers.Regexp | ||
|
||
export evaluate, DexError, DexModule, juliaize, NativeFunction, @dex_func_str | ||
export evaluate, DexError, DexModule, dexize, juliaize, NativeFunction, @dex_func_str | ||
|
||
include("api_types.jl") | ||
include("api.jl") | ||
include("evaluate.jl") | ||
include("native_function.jl") | ||
include("chainrules.jl") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
# use this to disable free'ing haskell objects after we have closed the RTS | ||
const NO_FREE = Ref(false) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
|
||
"""
    ChainRulesCore.frule((_, ẋs...), f_native::NativeFunction, xs...)

Forward-mode rule for calling the Dex function wrapped by `f_native` on `xs`.

A Dex module is evaluated that defines `frule_inner`, which linearizes the wrapped
function at the primal arguments and applies the resulting pushforward to the tangent
arguments. `frule_inner` is then wrapped as a `NativeFunction` and called with
`xs..., ẋs...`; its result (the Dex pair `(y, dy)`) is returned.

Tangent binders are generated only for non-implicit primal binders, and each tangent is
declared with the same type as its primal.
"""
function ChainRulesCore.frule((_, ẋs...), f_native::NativeFunction{R}, xs...) where R
    atom = f_native.atom
    # Bind the wrapped function under the name `f` so the generated source can refer to it.
    env = insert(atom.ctx, "f", atom.ptr)

    primal_binders = f_native.argument_signature
    primal_args_sig = repr_sig(primal_binders)
    primal_args = extract_arg_names(primal_binders)

    tangent_binders = generate_tangent_binders(primal_binders)
    tangent_args_sig = repr_sig(tangent_binders)
    tangent_args = extract_arg_names(tangent_binders)

    primal_res_sig = repr_result_sig(f_native.result_signature)
    # The result is a pair: the primal result and a tangent of the same type.
    dual_res_sig = "($primal_res_sig&$primal_res_sig)"

    # The body of the Dex `def` is indented relative to its header. The triple-quoted
    # literal is de-indented by Julia to the column of the closing delimiter, so the
    # `def` line starts at column 0 and the body keeps a 4-space indent.
    m = DexModule(
        """
        def frule_inner $primal_args_sig->$tangent_args_sig : $dual_res_sig =
            (y, pushforward) = linearize f $primal_args
            dy = pushforward $tangent_args
            (y, dy)
        """,
        env,
    )
    # Convert the Atom into `NativeFunction` so it can work with any argument type:
    frule_inner_native = NativeFunction(m.frule_inner)
    return frule_inner_native(xs..., ẋs...)
end
|
||
"""
    extract_arg_names(binds::Vector{Binder})

Return the names of `binds`, in order, joined by single spaces.
"""
function extract_arg_names(binds::Vector{Binder})
    names = (b.name for b in binds)
    return join(names, " ")
end

"""
    generate_tangent_binders(pbinds::Vector{Binder})

Construct the tangent binders matching the primal binders `pbinds`.

Implicit primal binders are skipped. For now only types whose tangent type matches the
primal type are supported.
"""
function generate_tangent_binders(pbinds::Vector{Binder})
    explicit = Iterators.filter(p -> !p.implicit, pbinds)
    return [generate_tangent_binder(p) for p in explicit]
end

"""
    generate_tangent_binder(pbind::Binder)

Construct the tangent binder for one primal binder: the name is the primal name prefixed
with `d`, the type is the primal type, and the last constructor argument is `false`
(presumably the `implicit` flag; confirm against the `Binder` definition).

Throws a `DomainError` if `pbind` is implicit.
"""
function generate_tangent_binder(pbind::Binder)
    if pbind.implicit
        throw(DomainError(pbind, "Implict arguments have no tangents"))
    end
    tangent_name = Symbol(:d, pbind.name)
    return Binder(tangent_name, pbind.type, false)
end
|
||
"""
    ChainRulesCore.frule((_, ẋ), f::Atom, x::Atom)

Forward-mode rule for applying the Dex atom `f` to the Dex atom `x`.

`f`, `x` and the tangent `ẋ` are bound as `f`, `x` and `dx` in a context derived from
`f.ctx`; a Dex module then linearizes `f` at `x` and applies the pushforward to `dx`.
Returns `(m.y, m.dy)` from that module.

Throws a `DomainError` if `ẋ` is not an `Atom`.
"""
function ChainRulesCore.frule((_, ẋ), f::Atom, x::Atom)
    if !(ẋ isa Atom)
        throw(DomainError(ẋ, "Tangent to an Atom must be an Atom"))
    end

    # Insert the bindings one at a time, threading the context through.
    env = f.ctx
    for (name, atom) in (("f", f), ("dx", ẋ), ("x", x))
        env = insert(env, name, atom.ptr)
    end

    m = DexModule(
        raw"""
        (y, pushforward) = linearize f x
        dy = pushforward dx
        """,
        env,
    )
    return m.y, m.dy
end
|
||
"""
    ChainRulesCore.rrule(f::Atom, x::Atom)

Reverse-mode rule for applying the Dex atom `f` to the Dex atom `x`.

`f` and `x` are bound by name in a context derived from `f.ctx`; a Dex module then
linearizes `f` at `x` and transposes the pushforward to obtain a pullback. Returns
`m.y` together with a Julia closure that takes an `Atom` cotangent and returns
`(NoTangent(), m.pullback(x̄))`.
"""
function ChainRulesCore.rrule(f::Atom, x::Atom)
    env = insert(insert(f.ctx, "f", f.ptr), "x", x.ptr)

    m = DexModule(
        raw"""
        (y, pushforward) = linearize f x
        pullback = transposeLinear pushforward
        """,
        env,
    )

    # The closure keeps `m` reachable for as long as the pullback is. This was written to
    # stop the env being destroyed by GC; per review it is no longer strictly required
    # while the context-destroying finalizer on `DexModule` is commented out.
    atom_pullback(x̄::Atom) = (NoTangent(), m.pullback(x̄))
    return m.y, atom_pullback
end
|
||
# Forward rule for `juliaize`: apply `juliaize` to both the primal and the tangent.
function ChainRulesCore.frule((_, ẋ), ::typeof(juliaize), x)
    return juliaize(x), juliaize(ẋ)
end
# Reverse rule for `juliaize` on an `Atom`.
function ChainRulesCore.rrule(::typeof(juliaize), x::Atom)
    primal = juliaize(x)
    ctx = x.ctx

    # The pullback takes a Julia-typed cotangent and gives back a Dex-typed cotangent,
    # created via `dexize` with the context of the primal atom.
    function juliaize_pullback(ȳ)
        return (NoTangent(), dexize(ȳ, ctx))
    end
    return primal, juliaize_pullback
end
|
||
|
||
# Forward rule for `dexize`: apply `dexize` to both the primal and the tangent.
function ChainRulesCore.frule((_, ẋ), ::typeof(dexize), x)
    return dexize(x), dexize(ẋ)
end
# Reverse rule for `dexize`.
function ChainRulesCore.rrule(::typeof(dexize), x)
    primal = dexize(x)

    # The pullback takes a Dex-typed cotangent and gives back a Julia-typed cotangent.
    function dexize_pullback(ȳ)
        return (NoTangent(), juliaize(ȳ))
    end
    return primal, dexize_pullback
end
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,9 +13,37 @@ end | |
|
||
# Display an `Atom` by showing the result of `print` applied to its pointer.
function Base.show(io::IO, atom::Atom)
    rendered = print(atom.ptr)
    return show(io, rendered)
end
|
||
""" | ||
juliaize(x) | ||
|
||
Get the corresponding Julia object from some output of Dex. | ||
""" | ||
function juliaize(x::CAtom)
    return bust_union(x)
end

# A raw Dex atom pointer is first wrapped in a `CAtom`.
function juliaize(x::Ptr{HsAtom})
    return juliaize(CAtom(x))
end

# An `Atom` is converted through its pointer.
function juliaize(x::Atom)
    return juliaize(x.ptr)
end

# Converting an `Atom` to a numeric type goes through its Julia value.
function Base.convert(::Type{T}, atom::Atom) where {T<:Number}
    return convert(T, juliaize(atom))
end
|
||
""" | ||
dexize(x) | ||
|
||
Get the corresponding Dex object from some output of Julia. | ||
|
||
NB: this is currently a hack that goes via string processing. | ||
""" | ||
function dexize(x::Float32, _module=PRELUDE, env=_module)
    # NOTE(review): a reviewer suspected that not all of these context arguments are
    # needed (the inline remark is truncated in the review); confirm before simplifying.
    return evaluate(_dex_float32_source(x), _module, env)
end

# Build the Dex source text for the Float32 `x`.
#
# Non-finite values map to the names `nan`, `infinity` and `-infinity`. Finite values are
# taken from `repr(x)`, which has the form "<mantissa>f<exponent>" (e.g. "1.5f0",
# "1.0f10", "1.0f-5"):
#   * exponent == 0: the mantissa alone, e.g. "1.5"
#   * exponent  > 0: "<mantissa> * (intpow 10.0 <exponent>)"
#   * exponent  < 0: "<mantissa> / (intpow 10.0 <-exponent>)"
# The negative case divides by a positive power rather than emitting
# `intpow 10.0 -5`, so `intpow` only ever receives a non-negative exponent.
# NOTE(review): for exponents below about -38 the power of ten presumably overflows a
# 32-bit Dex float, so such tiny values would come out as 0 — confirm if they matter.
function _dex_float32_source(x::Float32)
    isnan(x) && return "nan"
    x === Inf32 && return "infinity"
    x === -Inf32 && return "-infinity"

    mantissa, exponent_str = split(repr(x), 'f')
    exponent = parse(Int, exponent_str)
    if exponent == 0
        return String(mantissa)
    elseif exponent > 0
        return "$mantissa * (intpow 10.0 $exponent)"
    else
        return "$mantissa / (intpow 10.0 $(-exponent))"
    end
end
|
||
|
||
function (self::Atom)(args...) | ||
# TODO: Make those calls more hygenic | ||
|
@@ -60,12 +88,18 @@ julia> m.y | |
"84" | ||
``` | ||
""" | ||
function DexModule(source::AbstractString, parent_ctx=PRELUDE)
    # Evaluate `source` on top of `parent_ctx` (the prelude by default, which keeps the
    # one-argument call working). A null context means evaluation failed on the Dex side.
    ctx = dex_eval(parent_ctx, source)
    ctx == C_NULL && throw_from_dex()
    m = DexModule(ctx)
    finalizer(m) do _m
        # TODO (flagged in review as needing resolution): re-enable
        #     destroy_context(getfield(_m, :ctx))
        # It is disabled for now because it causes a lot of problems: a DexModule will
        # often go out of scope while an Atom attached to that context still exists.
        # Possibly the ctx needs to be a mutable struct everywhere, with the finalizer
        # attached there (which would also let us delete some manual destroys in other
        # places).
    end
    return m
end
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Dex function that doubles a `Float`; shared by the rule tests in this file.
const double_dex = evaluate(raw"\x:Float. 2.0 * x")

@testset "frule: dexize, evaluate, juliaize" begin
    # Chain the three forward rules, starting from primal 1.5 with tangent 10.
    x_dex, ẋ_dex = frule((NoTangent(), 10f0), dexize, 1.5f0)
    y_dex, ẏ_dex = frule((NoTangent(), ẋ_dex), double_dex, x_dex)
    y, ẏ = frule((NoTangent(), ẏ_dex), juliaize, y_dex)

    @test y === 3.0f0
    @test ẏ === 20f0
end
|
||
@testset "rrule: dexize, evaluate, juliaize" begin
    # Forward pass: Julia value -> Dex atom -> doubled Dex atom -> Julia value.
    input = 1.5f0
    x_dex, dexize_pb = rrule(dexize, input)
    y_dex, double_pb = rrule(double_dex, x_dex)
    y, juliaize_pb = rrule(juliaize, y_dex)
    @test y === 3.0f0

    # Reverse pass: feed the cotangent 10 back through the pullbacks in reverse order.
    ȳ = 10f0
    _, ȳ_dex = juliaize_pb(ȳ)
    _, x̄_dex = double_pb(ȳ_dex)
    _, input_bar = dexize_pb(x̄_dex)
    @test input_bar === 20f0
end
|
||
@testset "Integration Test: Zygote.jl" begin
    # Round trip Julia -> Dex -> Julia, differentiated as one composed function.
    double_via_dex = juliaize ∘ double_dex ∘ dexize
    primal, back = Zygote.pullback(double_via_dex, 1.5f0)

    @test primal == 3f0
    @test back(1f0) == (2f0,)
end
|
||
|
||
@testset "frule NativeFunction" begin
    # Scalar in, scalar out
    dex_func"decimate_dex = \x:Float. x/10.0"
    @test frule((NoTangent(), 50f0), decimate_dex, 150f0) === (15f0, 5f0)

    # Table in, scalar out
    dex_func"sum3_dex = \x:(Fin 3=>Float). sum x"
    @test frule((NoTangent(), [1f0, 10f0, 100f0]), sum3_dex, [1f0, 2f0, 3f0]) === (6f0, 111f0)

    # Scalar in, table out
    dex_func"twovec_dex = \x:(Float32). [x,x]"
    twovec_dex(1f2)  # plain call; its result is not checked
    @test frule((NoTangent(), 10f0), twovec_dex, 4f0) == ([4f0, 4f0], [10f0,10f0])

    # With Implicts
    dex_func"def mysum_dex (arg0:Int32)?-> (arg1:Fin arg0 => Float32) : Float32 = sum arg1"
    @test_broken frule((NoTangent(), [1f0, 10f0, 100f0, 1000f0]), mysum_dex, [1f0, 2f0, 3f0, 4f0]) === (10f0, 1111f0)

    # With multiple arguments
    # The expected value is spelled out so that this reports an unexpected pass once
    # multi-argument support works; without a comparison the expression can never be
    # `true`. For x + y at (1, 2) with tangents (10, 100): primal 3, tangent 110.
    dex_func"add_dex = \x:Float32 y:Float32. x+y"
    @test_broken frule((NoTangent(), 10f0, 100f0), add_dex, 1f0, 2f0) === (3f0, 110f0)
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't given it proper docs yet, because it only does Float32.
It's mostly just for testing purposes.
We need a proper API for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, we can extend it if you'd like. Just let me know what would be helpful to have.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it is something like `create_literal`, which works like `insert` except that instead of taking an `Atom` it takes a C-compatible value for an Int/Float32/Float64/Array.
Possibly `ctypes` doesn't allow that directly, so maybe it needs to be wrapped into a tagged union?
I guess accepting the same tagged union that comes out of the atom's pointer makes sense.
(Except right now that doesn't support arrays.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't a function that converts a `CAtom` into an `Atom` be sufficient? That's what I would imagine. And then we can add more cases to `CAtom` if that would be useful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, I think that is basically what I said badly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool, which cases would you like to have? I'm happy to add them for you
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`Float32`, `Float64` (particularly since we can't input those as literals, #497).
Arrays would be nice, but given we can't currently convert `Atom` to `CAtom` for arrays anyway, that doesn't matter so much.
Integer types would be nice for completeness, but are not particularly interesting for AD.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added that in #657