Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Series of tests for AD #114

Merged
merged 34 commits into from
Jun 15, 2020
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
2234d4f
Solved tests for Matern
theogf May 12, 2020
a6159e1
Solved tests for dotproduct
theogf May 12, 2020
4aeb0e3
First draft of AD tests
theogf May 12, 2020
b6a7901
Removing unnecessary functions and uncommented all cases
theogf May 12, 2020
f70adc1
Created two function for testing any kernel, any AD and compare with …
theogf May 14, 2020
2ae0cd6
Needed type promotion for ForwardDiff.jl
theogf May 14, 2020
d88dcff
Created indirection from Base.map to _map for creating adjoints
theogf May 14, 2020
6875aee
Created full adjoints for DotProduct and evaluate for Sinus
theogf May 14, 2020
44368fb
Fixing ambiguity for Identity transform
theogf May 14, 2020
b3142f6
Adding test dependencies for AD
theogf May 15, 2020
44ad0cd
Put everything under _map to avoid ambiguities
theogf May 15, 2020
07631b6
Tests passing for constant kernels/modified Zygote to return zeros in…
theogf May 15, 2020
960bad2
Spread tests for all base kernels
theogf May 16, 2020
3e620ae
Spread tests for all transforms
theogf May 16, 2020
24cb00d
Removed need to give a name
theogf May 16, 2020
5b2e580
Adding needed export
theogf May 16, 2020
0bba1a5
Removed test_AD
theogf May 16, 2020
7f52242
Readded all tests
theogf May 16, 2020
f1000b3
Fixed tests and added adjoint tests
theogf May 16, 2020
4023365
Fixing issues in tests and adding some output to avoid travis-ci prob…
theogf May 17, 2020
a73133b
Relaxed tolerance
theogf May 17, 2020
d586967
Added atol for test (for comparisons around 0)
theogf May 22, 2020
577518f
Rewrote testing code
theogf May 25, 2020
9d82e1c
Put a seed for FBM tests
theogf May 25, 2020
181341e
Remove adjoint for Sinus
theogf May 25, 2020
88c6af7
Import all Flux functions
theogf May 25, 2020
aa282a1
Fix parenthesis missing
theogf May 25, 2020
ffefd1f
Fixed some stupid testing bugs
theogf May 25, 2020
6b5ba4d
Corrected Tests Zygote Adjoints
theogf May 25, 2020
b6ddf52
Clearer failing messages
theogf May 25, 2020
5c7eb6a
Adding Project.toml to test folder
theogf Jun 12, 2020
a4e5bb2
Missing ForwardDiff and removed KernelFunctions
theogf Jun 12, 2020
686ad8c
Missing ReverseDiff
theogf Jun 12, 2020
e94973e
Removed passing tests for PeriodicKernel
theogf Jun 15, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ jobs:
- name: CompatHelper.main()
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: julia -e 'using CompatHelper; CompatHelper.main()'
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "test"])'
12 changes: 0 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,3 @@ StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9"
ZygoteRules = "0.2"
julia = "1.3"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker", "Flux"]
4 changes: 2 additions & 2 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ export spectral_mixture_kernel, spectral_mixture_product_kernel
using Compat
using Requires
using Distances, LinearAlgebra
using SpecialFunctions: logabsgamma, besselk
using ZygoteRules: @adjoint
using SpecialFunctions: logabsgamma, besselk, polygamma
using ZygoteRules: @adjoint, pullback
using StatsFuns: logtwo
using InteractiveUtils: subtypes
using StatsBase
Expand Down
11 changes: 5 additions & 6 deletions src/basekernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ end

@inline function kappa(κ::MaternKernel, d::Real)
ν = first(κ.ν)
iszero(d) ? one(d) :
exp(
(one(d) - ν) * logtwo - logabsgamma(ν)[1] +
ν * log(sqrt(2ν) * d) +
log(besselk(ν, sqrt(2ν) * d))
)
iszero(d) ? one(d) : _matern(ν, d)
end

function _matern(ν::Real, d::Real)
exp((one(d) - ν) * logtwo - loggamma(ν) + ν * log(sqrt(2ν) * d) + log(besselk(ν, sqrt(2ν) * d)))
end

metric(::MaternKernel) = Euclidean()
Expand Down
6 changes: 4 additions & 2 deletions src/distances/delta.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
struct Delta <: Distances.PreMetric
end

@inline function Distances._evaluate(::Delta,a::AbstractVector{T},b::AbstractVector{T}) where {T}
@inline function Distances._evaluate(::Delta, a::AbstractVector, b::AbstractVector) where {T}
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
return a == b
end

Distances.result_type(::Delta, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Distances do this for determining the type of the allocated matrix in pairwise!, if this is not defined it defaults to Float64 and therefore breaks AD for ForwardDiff and others :
https://github.com/JuliaStats/Distances.jl/blob/f69f7888c92458ae671c893d079ecf5fc8d8accd/src/generic.jl#L35
https://github.com/JuliaStats/Distances.jl/blob/f69f7888c92458ae671c893d079ecf5fc8d8accd/src/generic.jl#L203

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, we don't want to restrict it to Float64 only. My comment was just that Julia actually won't compile specialized versions of this function for different input types since you used Ta::Type instead of ::Type{Ta}.


@inline (dist::Delta)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
@inline (dist::Delta)(a::Number,b::Number) = a == b
@inline (dist::Delta)(a::Number, b::Number) = a == b
4 changes: 3 additions & 1 deletion src/distances/dotproduct.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
struct DotProduct <: Distances.PreMetric end
# struct DotProduct <: Distances.UnionSemiMetric end

@inline function Distances._evaluate(::DotProduct, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
@inline function Distances._evaluate(::DotProduct, a::AbstractVector, b::AbstractVector)
@boundscheck if length(a) != length(b)
throw(DimensionMismatch("first array has length $(length(a)) which does not match the length of the second, $(length(b))."))
end
return dot(a,b)
end

Distances.result_type(::DotProduct, Ta::Type, Tb::Type) = promote_type(Ta, Tb)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.


@inline Distances.eval_op(::DotProduct, a::Real, b::Real) = a * b
@inline (dist::DotProduct)(a::AbstractArray,b::AbstractArray) = Distances._evaluate(dist, a, b)
@inline (dist::DotProduct)(a::Number,b::Number) = a * b
4 changes: 3 additions & 1 deletion src/distances/sinus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ Distances.parameters(d::Sinus) = d.r
@inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r))

@inline function Distances._evaluate(d::Sinus, a::AbstractVector{T}, b::AbstractVector{T}) where {T}
Distances.result_type(::Sinus{T}, Ta::Type, Tb::Type) where {T} = promote_type(T, Ta, Tb)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And here.


@inline function Distances._evaluate(d::Sinus, a::AbstractVector, b::AbstractVector) where {T}
@boundscheck if (length(a) != length(b)) || length(a) != length(d.r)
throw(DimensionMismatch("Dimensions of the inputs are not matching : a = $(length(a)), b = $(length(b)), r = $(length(d.r))"))
end
Expand Down
6 changes: 3 additions & 3 deletions src/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ dim(t::ARDTransform) = length(t.v)
(t::ARDTransform)(x::Real) = first(t.v) * x
(t::ARDTransform)(x) = t.v .* x

Base.map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
Base.map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
Base.map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)
_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
_map(t::ARDTransform, x::ColVecs) = ColVecs(t.v .* x.X)
_map(t::ARDTransform, x::RowVecs) = RowVecs(t.v' .* x.X)

Base.isequal(t::ARDTransform, t2::ARDTransform) = isequal(t.v, t2.v)

Expand Down
2 changes: 1 addition & 1 deletion src/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transfor

(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)

function Base.map(t::ChainTransform, x::AbstractVector)
function _map(t::ChainTransform, x::AbstractVector)
return foldl((x, t) -> map(t, x), t.transforms; init=x)
end

Expand Down
6 changes: 3 additions & 3 deletions src/transform/functiontransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ end

(t::FunctionTransform)(x) = t.f(x)

Base.map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
Base.map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
Base.map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))
_map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x)
_map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1))
_map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2))

duplicate(t::FunctionTransform,f) = FunctionTransform(f)

Expand Down
6 changes: 3 additions & 3 deletions src/transform/lineartransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ end
(t::LinearTransform)(x::Real) = vec(t.A * x)
(t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x

Base.map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * x')
Base.map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
Base.map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
_map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * x')
_map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
_map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')

function Base.show(io::IO, t::LinearTransform)
print(io::IO, "Linear transform (size(A) = ", size(t.A), ")")
Expand Down
6 changes: 3 additions & 3 deletions src/transform/scaletransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ set!(t::ScaleTransform,ρ::Real) = t.s .= [ρ]

(t::ScaleTransform)(x) = first(t.s) .* x

Base.map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x
Base.map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)
Base.map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X)
_map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x
_map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)
_map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X)

Base.isequal(t::ScaleTransform,t2::ScaleTransform) = isequal(first(t.s),first(t2.s))

Expand Down
4 changes: 2 additions & 2 deletions src/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ duplicate(t::SelectTransform,θ) = t

(t::SelectTransform)(x::AbstractVector) = view(x, t.select)

Base.map(t::SelectTransform, x::ColVecs) = ColVecs(view(x.X, t.select, :))
Base.map(t::SelectTransform, x::RowVecs) = RowVecs(view(x.X, :, t.select))
_map(t::SelectTransform, x::ColVecs) = ColVecs(view(x.X, t.select, :))
_map(t::SelectTransform, x::RowVecs) = RowVecs(view(x.X, :, t.select))

Base.show(io::IO, t::SelectTransform) = print(io, "Select Transform (dims: ", t.select, ")")
8 changes: 2 additions & 6 deletions src/transform/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@ include("functiontransform.jl")
include("selecttransform.jl")
include("chaintransform.jl")

"""
apply(t::Transform, x; obsdim::Int=defaultobs)

Apply the transform `t` vector-wise on the array `x`
"""
apply
Base.map(t::Transform, x::AbstractVector) = _map(t, x)

"""
IdentityTransform()
Expand All @@ -20,7 +16,7 @@ Return exactly the input
struct IdentityTransform <: Transform end

(t::IdentityTransform)(x) = x
Base.map(::IdentityTransform, x::AbstractVector) = x
_map(::IdentityTransform, x::AbstractVector) = x

### TODO Maybe defining adjoints could help but so far it's not working

Expand Down
3 changes: 2 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
hadamard(x, y) = x .* y

loggamma(x) = first(logabsgamma(x))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no reason to define loggamma other than avoiding the type piracy when defining an adjoint for logabsgamma, I guess?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that really type piracy? logabsgamma return a value and a sign, but we are only interested in the value so a wrapper solves this problems and allows to create a sensible adjoint.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, loggamma is no type piracy. I was just wondering why you define the function at all and not just define a custom adjoint for logabsgamma that outputs nothing for the second argument - and my assumption was that you didn't define one because you wanted to avoid this type piracy (of defining an adjoint for a function that we don't own).


# Macro for checking arguments
macro check_args(K, param, cond, desc=string(cond))
quote
Expand Down Expand Up @@ -124,4 +126,3 @@ function validate_dims(x::AbstractVector, y::AbstractVector)
))
end
end

83 changes: 76 additions & 7 deletions src/zygote_adjoints.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,78 @@
## Adjoints Delta
@adjoint function evaluate(s::Delta, x::AbstractVector, y::AbstractVector)
evaluate(s, x, y), Δ -> begin
(nothing, nothing, nothing)
end
end

@adjoint function pairwise(d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
D = pairwise(d, X, Y; dims = dims)
if dims == 1
return D, Δ -> (nothing, nothing, nothing)
else
return D, Δ -> (nothing, nothing, nothing)
end
end

@adjoint function pairwise(d::Delta, X::AbstractMatrix; dims=2)
D = pairwise(d, X; dims = dims)
if dims == 1
return D, Δ -> (nothing, nothing)
else
return D, Δ -> (nothing, nothing)
end
end

## Adjoints DotProduct
@adjoint function evaluate(s::DotProduct, x::AbstractVector, y::AbstractVector)
dot(x, y), Δ -> begin
(nothing, Δ .* y, Δ .* x)
end
end

@adjoint function pairwise(d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix; dims=2)
D = pairwise(d, X, Y; dims = dims)
if dims == 1
return D, Δ -> (nothing, Δ * Y, (X' * Δ)')
else
return D, Δ -> (nothing, (Δ * Y')', X * Δ)
end
end

@adjoint function pairwise(d::DotProduct, X::AbstractMatrix; dims=2)
D = pairwise(d, X; dims = dims)
if dims == 1
return D, Δ -> (nothing, 2 * Δ * X)
else
return D, Δ -> (nothing, 2 * X * Δ)
end
end

## Adjoints Sinus
@adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
d = (x - y)
sind = sinpi.(d)
val = sum(abs2, sind ./ s.r)
gradx = 2π .* cospi.(d) .* sind ./ (s.r .^ 2)
val, Δ -> begin
((r = -2Δ .* abs2.(sind) ./ s.r,), Δ * gradx, - Δ * gradx)
end
end

@adjoint function loggamma(x)
first(logabsgamma(x)) , Δ -> (Δ .* polygamma(0, x), )
end

@adjoint function kappa(κ::MaternKernel, d::Real)
ν = first(κ.ν)
val, grad = pullback(_matern, ν, d)
return ((iszero(d) ? one(d) : val),
Δ -> begin
∇ = grad(Δ)
return ((ν = [∇[1]],), iszero(d) ? zero(d) : ∇[2])
end)
end
Comment on lines +66 to +74
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? The definition of kappa(::MaternKernel, ::Real) seems quite simple, so I would have assumed that AD just works?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is needed for the edge case where x = 0, matern is defined as a iszero(d) ? one(d) : _matern(d) and Zygote will just return nothing on this branch kind of thing. It breaks the backpropagation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does ifelse(iszero(d), one(d), _matern(d)) fix the Zygote issue? I've seen this pattern quite a bit in the Zygote source code, so maybe it's helpful in such cases?


@adjoint function ColVecs(X::AbstractMatrix)
back(Δ::NamedTuple) = (Δ.X,)
back(Δ::AbstractMatrix) = (Δ,)
Expand All @@ -22,10 +91,10 @@ end
return RowVecs(X), back
end

# @adjoint function evaluate(s::Sinus, x::AbstractVector, y::AbstractVector)
# d = evaluate(s, x, y)
# s = sum(sin.(π*(x-y)))
# d, Δ -> begin
# (Sinus(Δ ./ s.r), 2Δ .* cos.(x - y) * d, -2Δ .* cos.(x - y) * d)
# end
# end
@adjoint function Base.map(t::Transform, X::ColVecs)
pullback(_map, t, X)
end

@adjoint function Base.map(t::Transform, X::RowVecs)
pullback(_map, t, X)
end
21 changes: 21 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[deps]
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
devmotion marked this conversation as resolved.
Show resolved Hide resolved
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Distances = "0.9"
FiniteDifferences = "0.10"
Flux = "0.10"
Kronecker = "0.4"
PDMats = "0.9"
SpecialFunctions = "0.10"
Zygote = "0.4"
3 changes: 3 additions & 0 deletions test/basekernels/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@test kappa(k,2.0) == 0.0
@test KernelFunctions.metric(ZeroKernel()) == KernelFunctions.Delta()
@test repr(k) == "Zero Kernel"
test_ADs(ZeroKernel)
end
@testset "WhiteKernel" begin
k = WhiteKernel()
Expand All @@ -14,6 +15,7 @@
@test EyeKernel == WhiteKernel
@test metric(WhiteKernel()) == KernelFunctions.Delta()
@test repr(k) == "White Kernel"
test_ADs(WhiteKernel)
end
@testset "ConstantKernel" begin
c = 2.0
Expand All @@ -24,5 +26,6 @@
@test metric(ConstantKernel()) == KernelFunctions.Delta()
@test metric(ConstantKernel(c=2.0)) == KernelFunctions.Delta()
@test repr(k) == "Constant Kernel (c = $(c))"
test_ADs(c->ConstantKernel(c=first(c)), [c])
end
end
1 change: 1 addition & 0 deletions test/basekernels/cosine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
@test kappa(k,x) ≈ cospi(x) atol=1e-5
@test k(v1, v2) ≈ cospi(sqrt(sum(abs2.(v1-v2)))) atol=1e-5
@test repr(k) == "Cosine Kernel"
test_ADs(CosineKernel)
end
5 changes: 4 additions & 1 deletion test/basekernels/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
@test SEKernel == SqExponentialKernel
@test repr(k) == "Squared Exponential Kernel"
@test KernelFunctions.iskroncompatible(k) == true
test_ADs(SEKernel)
end
@testset "ExponentialKernel" begin
k = ExponentialKernel()
Expand All @@ -24,6 +25,7 @@
@test repr(k) == "Exponential Kernel"
@test LaplacianKernel == ExponentialKernel
@test KernelFunctions.iskroncompatible(k) == true
test_ADs(ExponentialKernel)
end
@testset "GammaExponentialKernel" begin
γ = 2.0
Expand All @@ -36,7 +38,8 @@
@test metric(GammaExponentialKernel(γ=2.0)) == SqEuclidean()
@test repr(k) == "Gamma Exponential Kernel (γ = $(γ))"
@test KernelFunctions.iskroncompatible(k) == true

test_ADs(γ -> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
@test_broken "Zygote gradient given γ"
#Coherence :
@test GammaExponentialKernel(γ=1.0)(v1,v2) ≈ SqExponentialKernel()(v1,v2)
@test GammaExponentialKernel(γ=0.5)(v1,v2) ≈ ExponentialKernel()(v1,v2)
Expand Down
1 change: 1 addition & 0 deletions test/basekernels/exponentiated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
@test k(v1,v2) ≈ exp(dot(v1,v2))
@test metric(ExponentiatedKernel()) == KernelFunctions.DotProduct()
@test repr(k) == "Exponentiated Kernel"
test_ADs(ExponentiatedKernel)
end
Loading