Skip to content

Commit

Permalink
Merge pull request #114 from theogf/test_AD
Browse files Browse the repository at this point in the history
Series of tests for AD
  • Loading branch information
theogf authored Jun 15, 2020
2 parents 3b0cf61 + e94973e commit fb37557
Show file tree
Hide file tree
Showing 48 changed files with 375 additions and 229 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ jobs:
- name: CompatHelper.main()
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: julia -e 'using CompatHelper; CompatHelper.main()'
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "test"])'
12 changes: 0 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,3 @@ StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9"
ZygoteRules = "0.2"
julia = "1.3"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker", "Flux"]
4 changes: 2 additions & 2 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel
using Compat
using Requires
using Distances, LinearAlgebra
using SpecialFunctions: logabsgamma, besselk
using ZygoteRules: @adjoint
using SpecialFunctions: logabsgamma, besselk, polygamma
using ZygoteRules: @adjoint, pullback
using StatsFuns: logtwo
using InteractiveUtils: subtypes
using StatsBase
Expand Down
11 changes: 5 additions & 6 deletions src/basekernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ end

@inline function kappa::MaternKernel, d::Real)
ν = first.ν)
iszero(d) ? one(d) :
exp(
(one(d) - ν) * logtwo - logabsgamma(ν)[1] +
ν * log(sqrt(2ν) * d) +
log(besselk(ν, sqrt(2ν) * d))
)
iszero(d) ? one(d) : _matern(ν, d)
end

function _matern::Real, d::Real)
exp((one(d) - ν) * logtwo - loggamma(ν) + ν * log(sqrt(2ν) * d) + log(besselk(ν, sqrt(2ν) * d)))
end

metric(::MaternKernel) = Euclidean()
Expand Down
6 changes: 4 additions & 2 deletions src/distances/delta.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
struct Delta <: Distances.PreMetric
end

@inline function Distances._evaluate(::Delta,a::AbstractVector{T},b::AbstractVector{T}) where {T}
@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) where {T}
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
return a == b
end

Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb)

@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
@inline (dist::Delta)(a::Number,b::Number) = a == b
@inline (dist::Delta)(a::Number, b::Number) = a == b
4 changes: 3 additions & 1 deletion src/distances/dotproduct.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
struct DotProduct <: Distances.PreMetric end
# struct DotProduct <: Distances.UnionSemiMetric end

@inline function Distances._evaluate(::DotProduct, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector)
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
return dot(a,b)
end

Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb)

@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b
@inline (dist::DotProduct)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist, a, b)
@inline (dist::DotProduct)(a::Number,b::Number) = a * b
4 changes: 3 additions & 1 deletion src/distances/sinus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ Distances.parameters(d::Sinus) = d.r
@inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r))

@inline function Distances._evaluate(d::Sinus, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
Distances.result_type(::Sinus{T}, Ta::Type, Tb::Type) where {T} = promote_type(T, Ta, Tb)

@inline function Distances._evaluate(d::Sinus, a::AbstractVector, b::AbstractVector) where {T}
@boundscheck if (length(a) != length(b)) || length(a) != length(d.r)
throw(DimensionMismatch("Dimensions of the inputs are not matching : a = $(length(a)), b = $(length(b)), r = $(length(d.r))"))
end
Expand Down
6 changes: 3 additions & 3 deletions src/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ dim(t::ARDTransform) = length(t.v)
(t::ARDTransform)(x::Real) = first(t.v) * x
(t::ARDTransform)(x) = t.v .* x

Base.map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
Base.map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
Base.map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)

Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)

Expand Down
2 changes: 1 addition & 1 deletion src/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transfor

(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)

function Base.map(t::ChainTransform, x::AbstractVector)
function _map(t::ChainTransform, x::AbstractVector)
return foldl((x, t) -> map(t, x), t.transforms; init=x)
end

Expand Down
6 changes: 3 additions & 3 deletions src/transform/functiontransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ end

(t::FunctionTransform)(x) = t.f(x)

Base.map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
Base.map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
Base.map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))
_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
_map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
_map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))

duplicate(t::FunctionTransform,f) = FunctionTransform(f)

Expand Down
6 changes: 3 additions & 3 deletions src/transform/lineartransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ end
(t::LinearTransform)(x::Real) = vec(t.A * x)
(t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x

Base.map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * x')
Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
_map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * x')
_map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
_map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')

function Base.show(io::IO, t::LinearTransform)
print(io::IO, "Linear transform (size(A) = ", size(t.A), ")")
Expand Down
6 changes: 3 additions & 3 deletions src/transform/scaletransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ set!(t::ScaleTransform,ρ::Real) = t.s .= [ρ]

(t::ScaleTransform)(x) = first(t.s) .* x

Base.map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x
Base.map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)
Base.map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X)
_map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x
_map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)
_map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X)

Base.isequal(t::ScaleTransform,t2::ScaleTransform) = isequal(first(t.s),first(t2.s))

Expand Down
4 changes: 2 additions & 2 deletions src/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ duplicate(t::SelectTransform,θ) = t

(t::SelectTransform)(x::AbstractVector) = view(x, t.select)

Base.map(t::SelectTransform, x::ColVecs) = ColVecs(view(x.X, t.select, :))
Base.map(t::SelectTransform, x::RowVecs) = RowVecs(view(x.X, :, t.select))
_map(t::SelectTransform, x::ColVecs) = ColVecs(view(x.X, t.select, :))
_map(t::SelectTransform, x::RowVecs) = RowVecs(view(x.X, :, t.select))

Base.show(io::IO, t::SelectTransform) = print(io, "Select Transform (dims: ", t.select, ")")
8 changes: 2 additions & 6 deletions src/transform/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@ include("functiontransform.jl")
include("selecttransform.jl")
include("chaintransform.jl")

"""
apply(t::Transform, x; obsdim::Int=defaultobs)

Apply the transform `t` vector-wise on the array `x`
"""
apply
Base.map(t::Transform, x::AbstractVector) = _map(t, x)

"""
IdentityTransform()
Expand All @@ -20,7 +16,7 @@ Return exactly the input
struct IdentityTransform <: Transform end

(t::IdentityTransform)(x) = x
Base.map(::IdentityTransform, x::AbstractVector) = x
_map(::IdentityTransform, x::AbstractVector) = x

### TODO Maybe defining adjoints could help but so far it's not working

Expand Down
3 changes: 2 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
hadamard(x, y) = x .* y

loggamma(x) = first(logabsgamma(x))

# Macro for checking arguments
macro check_args(K, param, cond, desc=string(cond))
quote
Expand Down Expand Up @@ -124,4 +126,3 @@ function validate_dims(x::AbstractVector, y::AbstractVector)
))
end
end

83 changes: 76 additions & 7 deletions src/zygote_adjoints.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,78 @@
## Adjoints Delta
@adjoint function evaluate(s::Delta, x::AbstractVector, y::AbstractVector)
evaluate(s, x, y), Δ -> begin
(nothing, nothing, nothing)
end
end

@adjoint function pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
D = pairwise(d, X, Y; dims = dims)
if dims == 1
return D, Δ -> (nothing, nothing, nothing)
else
return D, Δ -> (nothing, nothing, nothing)
end
end

@adjoint function pairwise(d::Delta, X::AbstractMatrix; dims=2)
D = pairwise(d, X; dims = dims)
if dims == 1
return D, Δ -> (nothing, nothing)
else
return D, Δ -> (nothing, nothing)
end
end

## Adjoints DotProduct
@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector)
dot(x, y), Δ -> begin
(nothing, Δ .* y, Δ .* x)
end
end

@adjoint function pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
D = pairwise(d, X, Y; dims = dims)
if dims == 1
return D, Δ -> (nothing, Δ * Y, (X' * Δ)')
else
return D, Δ -> (nothing, (Δ * Y')', X * Δ)
end
end

@adjoint function pairwise(d::DotProduct, X::AbstractMatrix; dims=2)
D = pairwise(d, X; dims = dims)
if dims == 1
return D, Δ -> (nothing, 2 * Δ * X)
else
return D, Δ -> (nothing, 2 * X * Δ)
end
end

## Adjoints Sinus
@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
d = (x - y)
sind = sinpi.(d)
val = sum(abs2, sind ./ s.r)
gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2)
val, Δ -> begin
((r = -2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, - Δ * gradx)
end
end

@adjoint function loggamma(x)
first(logabsgamma(x)) , Δ ->.* polygamma(0, x), )
end

@adjoint function kappa::MaternKernel, d::Real)
ν = first.ν)
val, grad = pullback(_matern, ν, d)
return ((iszero(d) ? one(d) : val),
Δ -> begin
= grad(Δ)
return ((ν = [∇[1]],), iszero(d) ? zero(d) : ∇[2])
end)
end

@adjoint function ColVecs(X::AbstractMatrix)
back::NamedTuple) =.X,)
back::AbstractMatrix) = (Δ,)
Expand All @@ -22,10 +91,10 @@ end
return RowVecs(X), back
end

# @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
# d = evaluate(s, x, y)
# s = sum(sin.(π*(x-y)))
# d, Δ -> begin
# (Sinus(Δ ./ s.r), 2Δ .* cos.(x - y) * d, -2Δ .* cos.(x - y) * d)
# end
# end
@adjoint function Base.map(t::Transform, X::ColVecs)
pullback(_map, t, X)
end

@adjoint function Base.map(t::Transform, X::RowVecs)
pullback(_map, t, X)
end
24 changes: 24 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Distances = "0.9"
FiniteDifferences = "0.10"
Flux = "0.10"
ForwardDiff = "0.10"
Kronecker = "0.4"
PDMats = "0.9"
ReverseDiff = "1.2"
SpecialFunctions = "0.10"
Zygote = "0.4"
3 changes: 3 additions & 0 deletions test/basekernels/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@test kappa(k,2.0) == 0.0
@test KernelFunctions.metric(ZeroKernel()) == KernelFunctions.Delta()
@test repr(k) == "Zero Kernel"
test_ADs(ZeroKernel)
end
@testset "WhiteKernel" begin
k = WhiteKernel()
Expand All @@ -14,6 +15,7 @@
@test EyeKernel == WhiteKernel
@test metric(WhiteKernel()) == KernelFunctions.Delta()
@test repr(k) == "White Kernel"
test_ADs(WhiteKernel)
end
@testset "ConstantKernel" begin
c = 2.0
Expand All @@ -24,5 +26,6 @@
@test metric(ConstantKernel()) == KernelFunctions.Delta()
@test metric(ConstantKernel(c=2.0)) == KernelFunctions.Delta()
@test repr(k) == "Constant Kernel (c = $(c))"
test_ADs(c->ConstantKernel(c=first(c)), [c])
end
end
1 change: 1 addition & 0 deletions test/basekernels/cosine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
@test kappa(k,x) cospi(x) atol=1e-5
@test k(v1, v2) cospi(sqrt(sum(abs2.(v1-v2)))) atol=1e-5
@test repr(k) == "Cosine Kernel"
test_ADs(CosineKernel)
end
5 changes: 4 additions & 1 deletion test/basekernels/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
@test SEKernel == SqExponentialKernel
@test repr(k) == "Squared Exponential Kernel"
@test KernelFunctions.iskroncompatible(k) == true
test_ADs(SEKernel)
end
@testset "ExponentialKernel" begin
k = ExponentialKernel()
Expand All @@ -24,6 +25,7 @@
@test repr(k) == "Exponential Kernel"
@test LaplacianKernel == ExponentialKernel
@test KernelFunctions.iskroncompatible(k) == true
test_ADs(ExponentialKernel)
end
@testset "GammaExponentialKernel" begin
γ = 2.0
Expand All @@ -36,7 +38,8 @@
@test metric(GammaExponentialKernel=2.0)) == SqEuclidean()
@test repr(k) == "Gamma Exponential Kernel (γ = $(γ))"
@test KernelFunctions.iskroncompatible(k) == true

test_ADs-> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Zygote gradient given γ"
#Coherence :
@test GammaExponentialKernel=1.0)(v1,v2) SqExponentialKernel()(v1,v2)
@test GammaExponentialKernel=0.5)(v1,v2) ExponentialKernel()(v1,v2)
Expand Down
1 change: 1 addition & 0 deletions test/basekernels/exponentiated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
@test k(v1,v2) exp(dot(v1,v2))
@test metric(ExponentiatedKernel()) == KernelFunctions.DotProduct()
@test repr(k) == "Exponentiated Kernel"
test_ADs(ExponentiatedKernel)
end
Loading

0 comments on commit fb37557

Please sign in to comment.