Skip to content

Commit

Permalink
AD: Tag dual values correctly (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikekre authored Mar 16, 2021
1 parent 7f13296 commit 4d20f8a
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Tensors"
uuid = "48a634ad-e948-5137-8d70-aa71f2a747f4"
version = "1.4.3"
version = "1.4.4"

[deps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down
97 changes: 49 additions & 48 deletions src/automatic_differentiation.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import ForwardDiff: Dual, partials, value
import ForwardDiff: Dual, partials, value, Tag

@static if isdefined(LinearAlgebra, :gradient)
import LinearAlgebra.gradient
Expand Down Expand Up @@ -193,92 +193,92 @@ end
# Loaders are supposed to take a tensor of real values and convert it
# into a tensor of dual values where the seeds are correctly defined.

@inline function _load(v::Number)
return Dual(v, one(v))
@inline function _load(v::Number, ::Tg) where Tg
return Dual{Tg}(v, one(v))
end

@inline function _load(v::Vec{1, T}) where {T}
@inbounds v_dual = Vec{1}((Dual(v[1], one(T)),))
@inline function _load(v::Vec{1, T}, ::Tg) where {T, Tg}
@inbounds v_dual = Vec{1}((Dual{Tg}(v[1], one(T)),))
return v_dual
end

@inline function _load(v::Vec{2, T}) where {T}
@inline function _load(v::Vec{2, T}, ::Tg) where {T, Tg}
o = one(T)
z = zero(T)
@inbounds v_dual = Vec{2}((Dual(v[1], o, z),
Dual(v[2], z, o)))
@inbounds v_dual = Vec{2}((Dual{Tg}(v[1], o, z),
Dual{Tg}(v[2], z, o)))
return v_dual
end

@inline function _load(v::Vec{3, T}) where {T}
@inline function _load(v::Vec{3, T}, ::Tg) where {T, Tg}
o = one(T)
z = zero(T)
@inbounds v_dual = Vec{3}((Dual(v[1], o, z, z),
Dual(v[2], z, o, z),
Dual(v[3], z, z, o)))
@inbounds v_dual = Vec{3}((Dual{Tg}(v[1], o, z, z),
Dual{Tg}(v[2], z, o, z),
Dual{Tg}(v[3], z, z, o)))
return v_dual
end

# Second order tensors
@inline function _load(v::Tensor{2, 1, T}) where {T}
@inbounds v_dual = Tensor{2, 1}((Dual(get_data(v)[1], one(T)),))
@inline function _load(v::Tensor{2, 1, T}, ::Tg) where {T, Tg}
@inbounds v_dual = Tensor{2, 1}((Dual{Tg}(get_data(v)[1], one(T)),))
return v_dual
end

@inline function _load(v::SymmetricTensor{2, 1, T}) where {T}
@inbounds v_dual = SymmetricTensor{2, 1}((Dual(get_data(v)[1], one(T)),))
@inline function _load(v::SymmetricTensor{2, 1, T}, ::Tg) where {T, Tg}
@inbounds v_dual = SymmetricTensor{2, 1}((Dual{Tg}(get_data(v)[1], one(T)),))
return v_dual
end

@inline function _load(v::Tensor{2, 2, T}) where {T}
@inline function _load(v::Tensor{2, 2, T}, ::Tg) where {T, Tg}
data = get_data(v)
o = one(T)
z = zero(T)
@inbounds v_dual = Tensor{2, 2}((Dual(data[1], o, z, z, z),
Dual(data[2], z, o, z, z),
Dual(data[3], z, z, o, z),
Dual(data[4], z, z, z, o)))
@inbounds v_dual = Tensor{2, 2}((Dual{Tg}(data[1], o, z, z, z),
Dual{Tg}(data[2], z, o, z, z),
Dual{Tg}(data[3], z, z, o, z),
Dual{Tg}(data[4], z, z, z, o)))
return v_dual
end

@inline function _load(v::SymmetricTensor{2, 2, T}) where {T}
@inline function _load(v::SymmetricTensor{2, 2, T}, ::Tg) where {T, Tg}
data = get_data(v)
o = one(T)
o2 = convert(T, 1/2)
z = zero(T)
@inbounds v_dual = SymmetricTensor{2, 2}((Dual(data[1], o, z, z),
Dual(data[2], z, o2, z),
Dual(data[3], z, z, o)))
@inbounds v_dual = SymmetricTensor{2, 2}((Dual{Tg}(data[1], o, z, z),
Dual{Tg}(data[2], z, o2, z),
Dual{Tg}(data[3], z, z, o)))
return v_dual
end

@inline function _load(v::Tensor{2, 3, T}) where {T}
@inline function _load(v::Tensor{2, 3, T}, ::Tg) where {T, Tg}
data = get_data(v)
o = one(T)
z = zero(T)
@inbounds v_dual = Tensor{2, 3}((Dual(data[1], o, z, z, z, z, z, z, z, z),
Dual(data[2], z, o, z, z, z, z, z, z, z),
Dual(data[3], z, z, o, z, z, z, z, z, z),
Dual(data[4], z, z, z, o, z, z, z, z, z),
Dual(data[5], z, z, z, z, o, z, z, z, z),
Dual(data[6], z, z, z, z, z, o, z, z, z),
Dual(data[7], z, z, z, z, z, z, o, z, z),
Dual(data[8], z, z, z, z, z, z, z, o, z),
Dual(data[9], z, z, z, z, z, z, z, z, o)))
@inbounds v_dual = Tensor{2, 3}((Dual{Tg}(data[1], o, z, z, z, z, z, z, z, z),
Dual{Tg}(data[2], z, o, z, z, z, z, z, z, z),
Dual{Tg}(data[3], z, z, o, z, z, z, z, z, z),
Dual{Tg}(data[4], z, z, z, o, z, z, z, z, z),
Dual{Tg}(data[5], z, z, z, z, o, z, z, z, z),
Dual{Tg}(data[6], z, z, z, z, z, o, z, z, z),
Dual{Tg}(data[7], z, z, z, z, z, z, o, z, z),
Dual{Tg}(data[8], z, z, z, z, z, z, z, o, z),
Dual{Tg}(data[9], z, z, z, z, z, z, z, z, o)))
return v_dual
end

@inline function _load(v::SymmetricTensor{2, 3, T}) where {T}
@inline function _load(v::SymmetricTensor{2, 3, T}, ::Tg) where {T, Tg}
data = get_data(v)
o = one(T)
o2 = convert(T, 1/2)
z = zero(T)
@inbounds v_dual = SymmetricTensor{2, 3}((Dual(data[1], o, z, z, z, z, z),
Dual(data[2], z, o2, z, z, z, z),
Dual(data[3], z, z, o2, z, z, z),
Dual(data[4], z, z, z, o, z, z),
Dual(data[5], z, z, z, z, o2, z),
Dual(data[6], z, z, z, z, z, o)))
@inbounds v_dual = SymmetricTensor{2, 3}((Dual{Tg}(data[1], o, z, z, z, z, z),
Dual{Tg}(data[2], z, o2, z, z, z, z),
Dual{Tg}(data[3], z, z, o2, z, z, z),
Dual{Tg}(data[4], z, z, z, o, z, z),
Dual{Tg}(data[5], z, z, z, z, o2, z),
Dual{Tg}(data[6], z, z, z, z, z, o)))
return v_dual
end

Expand All @@ -301,13 +301,13 @@ julia> ∇f = gradient(norm, A)
julia> ∇f, f = gradient(norm, A, :all);
```
"""
function gradient(f::F, v::Union{SecondOrderTensor, Vec, Number}) where {F}
v_dual = _load(v)
function gradient(f::F, v::V) where {F, V <: Union{SecondOrderTensor, Vec, Number}}
v_dual = _load(v, Tag(f, V))
res = f(v_dual)
return _extract_gradient(res, v)
end
function gradient(f::F, v::Union{SecondOrderTensor, Vec, Number}, ::Symbol) where {F}
v_dual = _load(v)
function gradient(f::F, v::V, ::Symbol) where {F, V <: Union{SecondOrderTensor, Vec, Number}}
v_dual = _load(v, Tag(f, V))
res = f(v_dual)
return _extract_gradient(res, v), _extract_value(res)
end
Expand Down Expand Up @@ -431,9 +431,10 @@ function laplace(f::F, v) where F
end
const Δ = laplace

function Broadcast.broadcasted(::typeof(laplace), f::F, v::Vec{3}) where {F}
function Broadcast.broadcasted(::typeof(laplace), f::F, v::V) where {F, V <: Vec{3}}
@inbounds begin
vdd = _load(_load(v))
tag = Tag(f, V)
vdd = _load(_load(v, tag), tag)
res = f(vdd)
v1 = res[1].partials[1].partials[1] + res[1].partials[2].partials[2] + res[1].partials[3].partials[3]
v2 = res[2].partials[1].partials[1] + res[2].partials[2].partials[2] + res[2].partials[3].partials[3]
Expand Down
9 changes: 9 additions & 0 deletions test/test_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,13 @@ S(C) = S(C, μ, Kb)
@test hessian(f, x, :all)[3] f(x) x^3 * S
end
end
@testsection "mixed scalar/vec" begin
f(v::Vec, s::Number) = s * v[1] * v[2]
v = Vec((2.0, 3.0))
s = 4.0
@test gradient(x -> f(v, x), s)::Float64 v[1] * v[2]
@test gradient(x -> f(x, s), v)::Vec{2} Vec((s * v[2], s * v[1]))
@test gradient(y -> gradient(x -> f(y, x), s), v)::Vec{2} Vec((v[2], v[1]))
@test gradient(y -> gradient(x -> f(x, y), v), s)::Vec{2} Vec((v[2], v[1]))
end
end # testsection

3 comments on commit 4d20f8a

@fredrikekre
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fredrikekre
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/32091

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.4.4 -m "<description of version>" 4d20f8a6e989ff34e1eeb2022207da9096ddd23b
git push origin v1.4.4

Please sign in to comment.