Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Connecting to ChainRulesCore for JuliaLang AD compat #652

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion julia/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ uuid = "bb22f25d-cb49-471c-b017-930e329a2928"
version = "0.1.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CombinedParsers = "5ae71ed2-6f8a-4ed1-b94f-e14e8158f19e"

[compat]
ChainRulesCore = "^1.0"
CombinedParsers = "^0.2"
Zygote = "^0.6.22"
julia = "^1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["Test", "Zygote"]
6 changes: 4 additions & 2 deletions julia/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

DexCall provides a mechanism for calling dex-lang code from JuliaLang.
Three main mechanism are provided for this: `evaluate`, `DexModule` and the `dex_func` string macro.
Two helper methods are also provided: `juliaize` and `NativeFunction`.
Several helper methods are also provided: `juliaize`, `dexize`, and `NativeFunction`.

## `evaluate`: just run a single Dex expression.
`evaluate` takes in a Dex expression as a string and runs it, returning a `Atom` (see below).
Expand Down Expand Up @@ -53,7 +53,7 @@ julia> m.addTwo(m.y)
"[44., 44., 44.]"
```

## Atoms: `juliaize`, `NativeFunction`
## Atoms: `juliaize`, `dexize` and `NativeFunction`

`evaluate` and the contents of a `DexModule` are returned as `Atom`s.
These can be displayed, but not much else.
Expand Down Expand Up @@ -87,6 +87,8 @@ julia> typeof(convert(Int64, m.x))
Int64
```

The inverse of `juliaize` is `dexize`, it is currently very limited.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't given it proper docs yet, because it only does Float32.
It's mostly just for testing purposes.
We need a proper API for this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, we can extend it if you'd like. Just let me know what would be helpful to have.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it is something like create_literal that works like insert except instead of taking an Atom it takes a C compatible value for a Int/Float32/Float64/Array.
Possibly ctypes doesn't allow that directly, so maybe it needs to be wrapped into a tagged union?
I guess maybe accepting the same tagged union that comes out of the atom's pointer makes sense.
(Except right now that doesn't support arrays)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't a function that converts a CAtom into an Atom be sufficient? That's what I would imagine. And then we can add more cases to CAtom if that would be useful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I think that is basically what I said badly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, which cases would you like to have? I'm happy to add them for you

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Float32, Float64 (particularly since can't input those as literals #497)
Arrays would be nice, but given we can't currently convert Atom to CAtom for arrays anyway, that doesn't matter so much.
Integer types would be nice, for completeness, but not particularly interesting for AD.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added that in #657


To convert function `Atom`s into something you can execute as if it was a regular julia function use `NativeFunction`.
This will compile it and handing inputs and outputs without needing to del with `Atom`s directly.

Expand Down
4 changes: 3 additions & 1 deletion julia/src/DexCall.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"Calling Dex from Julia"
module DexCall
using ChainRulesCore
using CombinedParsers
using CombinedParsers.Regexp

export evaluate, DexError, DexModule, juliaize, NativeFunction, @dex_func_str
export evaluate, DexError, DexModule, dexize, juliaize, NativeFunction, @dex_func_str

include("api_types.jl")
include("api.jl")
include("evaluate.jl")
include("native_function.jl")
include("chainrules.jl")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

include("chainrules.jl")


# use this to disable free'ing haskell objects after we have closed the RTS
const NO_FREE = Ref(false)
Expand Down
8 changes: 0 additions & 8 deletions julia/src/api_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,3 @@ function CAtom(atm::Ptr{HsAtom})
iszero(success) && throw_from_dex()
return result[]
end

"""
juliaize(x)

Get the corresponding Julia object from some output of Dex.
"""
juliaize(x::CAtom) = bust_union(x)
juliaize(x::Ptr{HsAtom}) = juliaize(CAtom(x))
apaszke marked this conversation as resolved.
Show resolved Hide resolved
93 changes: 93 additions & 0 deletions julia/src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@

function ChainRulesCore.frule((_, ẋs...), f_native::NativeFunction{R}, xs...) where R
f = f_native.atom
env = f.ctx
env = insert(env, "f", f.ptr)

primal_binders = f_native.argument_signature
primal_args_sig = repr_sig(primal_binders)
primal_args = extract_arg_names(primal_binders)

tangent_binders = generate_tangent_binders(primal_binders)
tangent_args_sig = repr_sig(tangent_binders)
tangent_args = extract_arg_names(tangent_binders)

primal_res_sig = repr_result_sig(f_native.result_signature)
dual_res_sig = "($primal_res_sig&$primal_res_sig)"
m = DexModule("""
def frule_inner $primal_args_sig->$tangent_args_sig : $dual_res_sig =
(y, pushforward) = linearize f $primal_args
dy = pushforward $tangent_args
(y, dy)
""",
env
)
# Convert the Atom into `NativeFunction` so can work with any argument type:
frule_inner_native = NativeFunction(m.frule_inner)
return frule_inner_native(xs..., ẋs...)
end

extract_arg_names(binds::Vector{Binder}) = join((bind.name for bind in binds), " ")
"""
Given the `Binder` for the signature of a primal argument/s constructs the matching one
for the tangent.
For now we only support types with tangent type matcing primal type
"""
function generate_tangent_binders(pbinds::Vector{Binder})
return [generate_tangent_binder(pbind) for pbind in pbinds if !pbind.implicit]
end
function generate_tangent_binder(pbind::Binder)
pbind.implicit && throw(DomainError(pbind, "Implict arguments have no tangents"))
return Binder(Symbol(:d, pbind.name), pbind.type, false)
end

function ChainRulesCore.frule((_, ẋ), f::Atom, x::Atom)
ẋ isa Atom || throw(DomainError(ẋ, "Tangent to an Atom must be an Atom"))
env = f.ctx
env = insert(env, "f", f.ptr)
env = insert(env, "dx", ẋ.ptr)
env = insert(env, "x", x.ptr)

m = DexModule(raw"""
(y, pushforward) = linearize f x
dy = pushforward dx
""",
env
)
return m.y, m.dy
end

function ChainRulesCore.rrule(f::Atom, x::Atom)
env = f.ctx
env = insert(env, "f", f.ptr)
env = insert(env, "x", x.ptr)

m = DexModule(raw"""
(y, pushforward) = linearize f x
pullback = transposeLinear pushforward
""",
env
)

# It is important that we close over `m` as otherwise the env may be destroyed by GC
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is no longer the case since right now that finalizers is commented out.
And before this is merged we may way move to working out how to attach the finalizer to the context more directly.

pullback(x̄::Atom)= (NoTangent(), m.pullback(x̄))
return m.y, pullback
end

ChainRulesCore.frule((_, ẋ), ::typeof(juliaize), x) = juliaize(x), juliaize(ẋ)
function ChainRulesCore.rrule(::typeof(juliaize), x::Atom)
env= x.ctx

# pullback must take a julia typed cotangent and give back a dex typed cotangent
juliaize_pullback(ȳ) = (NoTangent(), dexize(ȳ, env))
return juliaize(x), juliaize_pullback
end


ChainRulesCore.frule((_, ẋ), ::typeof(dexize), x) = dexize(x), dexize(ẋ)
function ChainRulesCore.rrule(::typeof(dexize), x)
# pullback must take a dex typed cotangent and give back a julia typed cotangent
dexize_pullback(ȳ) = (NoTangent(), juliaize(ȳ))
return dexize(x), dexize_pullback
end

40 changes: 37 additions & 3 deletions julia/src/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,37 @@ end

Base.show(io::IO, atom::Atom) = show(io, print(atom.ptr))

"""
juliaize(x)

Get the corresponding Julia object from some output of Dex.
"""
juliaize(x::CAtom) = bust_union(x)
juliaize(x::Ptr{HsAtom}) = juliaize(CAtom(x))
juliaize(x::Atom) = juliaize(x.ptr)
Base.convert(::Type{T}, atom::Atom) where {T<:Number} = convert(T, juliaize(atom))

"""
dexize(x)

Get the corresponding Dex object from some output of Julia.

NB: this is currently a hack that goes via string processing.
"""
function dexize(x::Float32, _module=PRELUDE, env=_module)
Copy link
Contributor Author

@oxinabox oxinabox Sep 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect we do not need to take a env and a _module argument.
We are making literals, we just need one for where the literal will exist?
I don't really understand the difference between them

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _module and env parameters are a bit of a hack that I added for Python bindings. The rough idea was that module is the module an output atom declares itself to be defined in, while env is the scope that's really used to evaluate the expression. This is used in the __call__ implementation where we temporarily extend the prelude with new names that refer to arguments, but then we want to pretend that the result is still defined in the original module that doesn't have those dummies. But now that I think about it, it's only well defined for non-dependent functions, so we should find a different workaround...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In any case, I think you can safely ignore env and just use the _module.

isnan(x) && return evaluate("nan", _module, env)
x === Inf32 && return evaluate("infinity", _module, env)
x === -Inf32 && return evaluate("-infinity", _module, env)

str = repr(x)
if endswith(str, "f0")
evaluate(str[1:end-2], _module, env)
else
# convert "123f45" into "123 * (intpow 10.0 45)"
evaluate(replace(str, "f"=> " * (intpow 10.0 ") * ")", _module, env)
end
end


function (self::Atom)(args...)
# TODO: Make those calls more hygenic
Expand Down Expand Up @@ -60,12 +88,18 @@ julia> m.y
"84"
```
"""
function DexModule(source::AbstractString)
ctx = dex_eval(PRELUDE, source)
function DexModule(source::AbstractString, parent_ctx=PRELUDE)
ctx = dex_eval(parent_ctx, source)
ctx == C_NULL && throw_from_dex()
m = DexModule(ctx)
finalizer(m) do _m
destroy_context(getfield(_m, :ctx))
# TODO: Undo commenting this out. But for now this causes a lot of problems.
# DexModule will often go out of scope, while a Atom attached to that context still
# exists. Possibly we need to make the ctx a mutable struct everywhere, and then
# attach the finalizer there.
#(also will let us delete some manual destroys in other palces)

#destroy_context(getfield(_m, :ctx))
Comment on lines +96 to +102
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs resolving

end
return m
end
Expand Down
36 changes: 32 additions & 4 deletions julia/src/native_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,32 @@ end
ArrayBuilder{T}(size) where T = ArrayBuilder{T,length(size)}(size)


"representation of this as it would appear in a dex `def` function signature"
repr_sig(binders::AbstractVector{Binder}) = join(Iterators.map(repr_sig, binders), "->")
function repr_sig(binder::Binder)
str = "($(binder.name):$(repr_sig(binder.type)))"
if binder.implicit
str *= "?"
end
return str
end
function repr_sig(builder::ArrayBuilder{T}) where {T}
sizes_repr = Iterators.map(size_element -> "Fin $size_element", builder.size)
return join(sizes_repr, "=>") * "=>" * repr_sig(T)
end

# For most types, like Int32 and Float64 Dex and Julia use identical names
# and for Symbols represent implicts they are also are made into strings by `string`
repr_sig(x) = string(x)
# TODO: Word8, Char etc?


"representation of this as it would appear in the result part of a dex `def` function signature"
repr_result_sig(x) = repr_sig(x)
function repr_result_sig(binders::AbstractVector{Binder})
return "(" * join(Iterators.map(repr_result_sig, binders), "&") * ")"
end
repr_result_sig(binder::Binder) = repr_result_sig(binder.type)

"""
NativeFunction{R}
Expand All @@ -98,14 +124,14 @@ Usually constructed using [`@dex_func_str`](@ref),
or via `NativeFunction(atom)` on some [`DexCall.Atom`](@ref).
"""
struct NativeFunction{R} <: Function
c_func_ptr::Ptr{Nothing}
atom::Atom # non-compiled Atom form of this function
c_func_ptr::Ptr{Nothing} # compiled C API form of this function
argument_signature::Vector{Binder}
result_signature::Vector{Binder}
end

NativeFunction(atom::Atom, jit=JIT) = NativeFunction(atom.ptr, atom.ctx, jit)
function NativeFunction(atom::Ptr{HsAtom}, ctx=PRELUDE, jit=JIT)
c_func_ptr = compile(ctx, atom, jit)
function NativeFunction(atom::Atom, ctx=atom.ctx, jit=JIT)
c_func_ptr = compile(ctx, atom.ptr, jit)
sig_ptr = get_function_signature(c_func_ptr, jit)
sig_ptr == C_NULL && error("Failed to retrieve the function signature")

Expand All @@ -114,6 +140,7 @@ function NativeFunction(atom::Ptr{HsAtom}, ctx=PRELUDE, jit=JIT)
result_signature = parse_sig(signature.res)
R = result_type(result_signature)
f = NativeFunction{R}(
atom,
c_func_ptr,
parse_sig(signature.arg),
result_signature
Expand Down Expand Up @@ -295,6 +322,7 @@ function parse_sig(sig)
parser("i32"=>Int32),
parser("i64"=>Int64),
parser("i8"=>Int8),
# TODO: Word8, Char etc?
)
size_ele = NumericParser(Int) | name
sizes = join(Repeat(size_ele),",")
Expand Down
52 changes: 52 additions & 0 deletions julia/test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
const double_dex = evaluate(raw"\x:Float. 2.0 * x")

@testset "frule: dexize, evaluate, juliaize" begin
a, ȧ = frule((NoTangent(), 10f0), dexize, 1.5f0)
b, ḃ = frule((NoTangent(), ȧ), double_dex, a)
c, ċ = frule((NoTangent(), ḃ), juliaize, b)
@test c === 3.0f0
@test ċ === 20f0
end

@testset "rrule: dexize, evaluate, juliaize" begin
x = 1.5f0
a, a_pb = rrule(dexize, x)
b, b_pb = rrule(double_dex, a)
c, c_pb = rrule(juliaize, b)

@test c === 3.0f0
c̄ = 10f0
_, b̄ = c_pb(c̄)
_, ā = b_pb(b̄)
_, x̄ = a_pb(ā)

@test x̄ === 20f0
end

@testset "Integration Test: Zygote.jl" begin
double_via_dex = juliaize ∘ double_dex ∘ dexize
y, pb= Zygote.pullback(double_via_dex, 1.5f0)
@test y == 3f0
@test pb(1f0) == (2f0,)
end


@testset "frule NativeFunction" begin
dex_func"decimate_dex = \x:Float. x/10.0"
@test frule((NoTangent(), 50f0), decimate_dex, 150f0) === (15f0, 5f0)

dex_func"sum3_dex = \x:(Fin 3=>Float). sum x"
@test frule((NoTangent(), [1f0, 10f0, 100f0]), sum3_dex, [1f0, 2f0, 3f0]) === (6f0, 111f0)

dex_func"twovec_dex = \x:(Float32). [x,x]"
twovec_dex(1f2)
@test frule((NoTangent(), 10f0), twovec_dex, 4f0) == ([4f0, 4f0], [10f0,10f0])

# With Implicts
dex_func"def mysum_dex (arg0:Int32)?-> (arg1:Fin arg0 => Float32) : Float32 = sum arg1"
@test_broken frule((NoTangent(), [1f0, 10f0, 100f0, 1000f0]), mysum_dex, [1f0, 2f0, 3f0, 4f0]) === (10f0, 1111f0)

# With multiple arguments
dex_func"add_dex = \x:Float32 y:Float32. x+y"
@test_broken frule((NoTangent(), 10f0, 100f0), add_dex, 1f0, 2f0)
end
10 changes: 10 additions & 0 deletions julia/test/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
@test juliaize(evaluate("IToW8 65")) === Int8(65)
end

@testset "dexize" begin
@test juliaize(dexize(0f0)) === 0f0
@test juliaize(dexize(42f0)) === 42f0
@test juliaize(dexize(123f15)) === 123f15
@test dexize(123f15) isa DexCall.Atom
@test isnan(juliaize(dexize(NaN32)))
@test (juliaize(dexize(Inf32))) == Inf32
@test (juliaize(dexize(-Inf32))) == -Inf32
end

@testset "Atom function call" begin
m = DexModule("""
def addOne (x: Float) : Float = x + 1.0
Expand Down
15 changes: 15 additions & 0 deletions julia/test/native_function.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

@testset "native_function.jl" begin
@testset "signature parser" begin
# Testing Implementation details, can remove if implementation changes
@testset "$example" for example in (
"arg0:f32",
"arg0:f32,arg1:f32",
Expand All @@ -17,6 +18,20 @@
@test DexCall.parse_sig(example) isa Vector{DexCall.Binder}
end
end

@testset "signature repr" begin
# Testing Implementation details, can remove if implementation changes
as_in_dex_sig = DexCall.repr_sig ∘ DexCall.parse_sig
@test as_in_dex_sig("arg0:f32") == "(arg0:Float32)"
@test as_in_dex_sig("arg0:f32,arg1:f32") == "(arg0:Float32)->(arg1:Float32)"
@test as_in_dex_sig("arg0:i64,arg1:i32") == "(arg0:Int64)->(arg1:Int32)"
@test as_in_dex_sig("arg0:f32[10]") == "(arg0:Fin 10=>Float32)"
@test as_in_dex_sig("?arg0:i32,arg1:f32[arg0]") == "(arg0:Int32)?->(arg1:Fin arg0=>Float32)"
@test as_in_dex_sig("arg2:f32[arg0]") == "(arg2:Fin arg0=>Float32)"
@test as_in_dex_sig("?arg0:i32,?arg1:i32,arg2:f32[arg0,arg1]") == "(arg0:Int32)?->(arg1:Int32)?->(arg2:Fin arg0=>Fin arg1=>Float32)"
@test as_in_dex_sig("arg3:f32[arg1,arg0]") == "(arg3:Fin arg1=>Fin arg0=>Float32)"
@test as_in_dex_sig("arg0:f32,?arg1:i32,arg2:f32[arg1]") == "(arg0:Float32)->(arg1:Int32)?->(arg2:Fin arg1=>Float32)"
end

@testset "dex_func anon funcs" begin
@test dex_func"\x:Float. exp x"(0f0) === 1f0
Expand Down
Loading