This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Gamma family function support #321

Open · wants to merge 4 commits into base: master

1 change: 1 addition & 0 deletions Project.toml
@@ -10,6 +10,7 @@ CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
CUDAnative = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
3 changes: 2 additions & 1 deletion src/CuArrays.jl
@@ -6,7 +6,7 @@ using GPUArrays

export CuArray, CuVector, CuMatrix, CuVecOrMat, cu

import LinearAlgebra
import LinearAlgebra, SpecialFunctions

using Adapt

@@ -31,6 +31,7 @@ include("array.jl")
include("subarray.jl")
include("utils.jl")
include("indexing.jl")
include("special/gamma.jl")
Review comment (Member): Just make it special.jl, no need for directories with single source files.

include("broadcast.jl")
include("matmul.jl")
include("mapreduce.jl")
5 changes: 5 additions & 0 deletions src/broadcast.jl
@@ -40,6 +40,11 @@ for f in libdevice
@eval cufunc(::typeof(Base.$f)) = CUDAnative.$f
end

cufunc(::typeof(SpecialFunctions.lbeta)) = CuArrays.lbeta
cufunc(::typeof(SpecialFunctions.lgamma)) = CuArrays.lgamma
cufunc(::typeof(SpecialFunctions.digamma)) = CuArrays.digamma
cufunc(::typeof(SpecialFunctions.trigamma)) = CuArrays.trigamma

#broadcast ^
culiteral_pow(::typeof(^), x::Union{Float32, Float64}, ::Val{0}) = one(x)
culiteral_pow(::typeof(^), x::Union{Float32, Float64}, ::Val{1}) = x
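These cufunc overloads are what route a plain SpecialFunctions broadcast to the GPU: when a broadcast over a CuArray hits one of the mapped functions, it is swapped for its CuArrays counterpart before the kernel is compiled. A minimal usage sketch, assuming this branch is installed and a working CUDA device is available:

using CuArrays, SpecialFunctions

xs = cu(rand(Float32, 1024) .+ 1f0)   # positive inputs, away from digamma's poles
ys = SpecialFunctions.digamma.(xs)    # rewritten to CuArrays.digamma inside the kernel
zs = SpecialFunctions.lbeta.(xs, xs)  # the two-argument lbeta is handled the same way
Array(ys)                             # copy the results back to the host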
15 changes: 13 additions & 2 deletions src/forwarddiff.jl
@@ -37,9 +37,20 @@ ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :abs, 1)] = x ->
:(signbit(x) ? -one(x) : one(x))
eval(ForwardDiff.unary_dual_definition(:CUDAnative, :abs))

# byhand: lgamma
ForwardDiff.DiffRules.@define_diffrule CuArrays.lgamma(a) = :(CuArrays.digamma($a))
eval(ForwardDiff.unary_dual_definition(:CuArrays, :lgamma))

ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :pow, 2)] = (x, y) ->
replace_device.(ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base, :^, 2)](x, y))
# byhand: digamma
ForwardDiff.DiffRules.@define_diffrule CuArrays.digamma(a) = :(CuArrays.trigamma($a))
eval(ForwardDiff.unary_dual_definition(:CuArrays, :digamma))

# byhand: lbeta
ForwardDiff.DiffRules.@define_diffrule CuArrays.lbeta(a, b) = :(CuArrays.digamma($a) - CuArrays.digamma($a + $b)), :(CuArrays.digamma($b) - CuArrays.digamma($a + $b))
eval(ForwardDiff.binary_dual_definition(:CuArrays, :lbeta))

ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:CUDAnative, :pow, 2)] =
(x, y) -> replace_device.(ForwardDiff.DiffRules.DEFINED_DIFFRULES[(:Base, :^, 2)](x, y))

@eval begin
ForwardDiff.@define_binary_dual_op(
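The three by-hand rules above encode the standard identities d/dx lgamma(x) = digamma(x), d/dx digamma(x) = trigamma(x), and ∂/∂a lbeta(a, b) = digamma(a) - digamma(a + b) (and symmetrically in b). A small CPU-side sanity check of the same identities, assuming a SpecialFunctions version that still exports lgamma and ForwardDiff's stock rules for it; lbeta_check is a hypothetical local helper, not part of this PR:

using ForwardDiff, SpecialFunctions

x = 2.3
ForwardDiff.derivative(SpecialFunctions.lgamma, x) ≈ SpecialFunctions.digamma(x)
ForwardDiff.derivative(SpecialFunctions.digamma, x) ≈ SpecialFunctions.trigamma(x)

# Spell lbeta out via lgamma so the dual numbers flow through known rules.
lbeta_check(a, b) = SpecialFunctions.lgamma(a) + SpecialFunctions.lgamma(b) - SpecialFunctions.lgamma(a + b)
a, b = 1.7, 0.9
ForwardDiff.derivative(t -> lbeta_check(t, b), a) ≈
    SpecialFunctions.digamma(a) - SpecialFunctions.digamma(a + b)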
63 changes: 63 additions & 0 deletions src/special/gamma.jl
@@ -0,0 +1,63 @@
# This file is heavily adapted from https://github.com/JuliaMath/SpecialFunctions.jl.
# License is MIT: http://julialang.org/license

function lgamma(x)
return CUDAnative.lgamma(x)
end

function digamma(x)
if x <= 0 # reflection formula
ψ = -π / CUDAnative.tan(π * x)
x = 1 - x
else
ψ = zero(x)
end
if x < 7
# shift using recurrence formula
ν = one(x)
n = 7 - CUDAnative.floor(x)
while ν <= n - 1
ψ -= inv(x + ν)
ν += one(x)
end
ψ -= inv(x)
x += n
end
t = inv(x)
ψ += CUDAnative.log(x) - 0.5 * t
t *= t # 1/z^2
# the coefficients here are Float64(bernoulli[2:9] .// (2*(1:8)))
ψ -= t * @evalpoly(t,0.08333333333333333,-0.008333333333333333,0.003968253968253968,-0.004166666666666667,0.007575757575757576,-0.021092796092796094,0.08333333333333333,-0.4432598039215686)

Review comment: Btw, this will convert into Float64; maybe it ought to be dependent on the input or use Float32 by default?

return ψ
end

function _trigamma(x)
ψ = zero(x)
if x < 8
# shift using recurrence formula
n = 8 - CUDAnative.floor(x)
ψ += inv(x)^2
ν = one(x)
while ν <= n - 1
ψ += inv(x + ν)^2
ν += one(x)
end
x += n
end
t = inv(x)
w = t * t # 1/z^2
ψ += t + 0.5 * w
# the coefficients here are Float64(bernoulli[2:9])
ψ += t * w * @evalpoly(w,0.16666666666666666,-0.03333333333333333,0.023809523809523808,-0.03333333333333333,0.07575757575757576,-0.2531135531135531,1.1666666666666667,-7.092156862745098)
return ψ
end

function trigamma(x)
if x <= 0 # reflection formula
return (π / CUDAnative.sin(π * x))^2 - _trigamma(1 - x)
else
return _trigamma(x)
end
end

lbeta(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)
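
For reference (none of the following is part of the diff): both kernels follow the SpecialFunctions scheme of shifting the argument past the threshold with the recurrences ψ(x) = ψ(x+1) - 1/x and ψ'(x) = ψ'(x+1) + 1/x², then summing the standard asymptotic series, whose Bernoulli-number coefficients are the Float64 literals in the two @evalpoly calls:

\psi(x)  \sim \ln x - \frac{1}{2x} - \sum_{k \ge 1} \frac{B_{2k}}{2k\,x^{2k}},
\qquad
\psi'(x) \sim \frac{1}{x} + \frac{1}{2x^2} + \sum_{k \ge 1} \frac{B_{2k}}{x^{2k+1}},

with the reflection formulas ψ(x) = ψ(1-x) - π·cot(πx) and ψ'(x) = π²/sin²(πx) - ψ'(1-x) covering x ≤ 0.

This is also where the reviewer's Float64 remark applies: the coefficients are Float64 literals, so Float32 inputs get promoted partway through the computation. A small host-side illustration (hypothetical, not in the PR), with oftype shown as one possible way to keep the input's precision:

x = 1.5f0                                    # Float32 input
t = inv(x)                                   # still Float32
typeof(t * 0.08333333333333333)              # Float64: the literal promotes the product
typeof(t * oftype(t, 0.08333333333333333))   # Float32: tying the constant to the input avoids it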
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -24,6 +24,7 @@ include("fft.jl")
include("sparse.jl")
include("solver.jl")
include("sparse_solver.jl")
include("special.jl")
include("dnn.jl")
include("forwarddiff.jl")

96 changes: 96 additions & 0 deletions test/special.jl
@@ -0,0 +1,96 @@
using Test
import SpecialFunctions
using Flux: Tracker
using CuArrays

n = 1000

xs_lgamma = randn(Float32, n); xs_lgamma_cu = cu(xs_lgamma)
xs_digamma = randn(Float32, n); xs_digamma_cu = cu(xs_digamma)
xs_trigamma = randn(Float32, n); xs_trigamma_cu = cu(xs_trigamma)
xs_lbeta_tuple = (randn(Float32, n), randn(Float32, n))
xs_lbeta_tuple = map(xs -> abs.(xs), xs_lbeta_tuple); xs_lbeta_cu_tuple = map(cu, xs_lbeta_tuple)

# Flatten the tuple of tracked gradient arrays into one plain array for elementwise comparison.
catgrads(grads) = cat(map(ta -> ta.data, grads)...; dims=1)
# Gradient of x -> sum(f.(x)) (one- and two-argument forms) via Flux's Tracker.
g∑fx(f, xs) = catgrads(Tracker.gradient(_xs -> sum(f.(_xs)), xs))
g∑fx(f, xs, ys) = catgrads(Tracker.gradient((_xs, _ys) -> sum(f.(_xs, _ys)), xs, ys))

results = Dict()
@testset "Forward evaluation" begin
fn = :lgamma
@testset "$fn" begin
lgamma_val_cpu = @time SpecialFunctions.lgamma.(xs_lgamma)
lgamma_val_gpu = @time CuArrays.lgamma.(xs_lgamma_cu)
lgamma_val_gpu = Array(lgamma_val_gpu)
for i = 1:n
@test lgamma_val_cpu[i] ≈ lgamma_val_gpu[i]
end
results[fn] = (lgamma_val_cpu, lgamma_val_gpu)
end

fn = :digamma
@testset "$fn" begin
digamma_val_cpu = @time SpecialFunctions.digamma.(xs_digamma)
digamma_val_gpu = @time CuArrays.digamma.(xs_digamma_cu)
digamma_val_gpu = Array(digamma_val_gpu)
for i = 1:n
@test digamma_val_cpu[i] ≈ digamma_val_gpu[i]
end
results[fn] = (digamma_val_cpu, digamma_val_gpu)
end

fn = :trigamma
@testset "$fn" begin
trigamma_val_cpu = @time SpecialFunctions.trigamma.(xs_trigamma)
trigamma_val_gpu = @time CuArrays.trigamma.(xs_trigamma_cu)
trigamma_val_gpu = Array(trigamma_val_gpu)
for i = 1:n
@test trigamma_val_cpu[i] ≈ trigamma_val_gpu[i]
end
results[fn] = (trigamma_val_cpu, trigamma_val_gpu)
end

fn = :lbeta
@testset "$fn" begin
lbeta_val_cpu = @time SpecialFunctions.lbeta.(xs_lbeta_tuple...)
lbeta_val_gpu = @time CuArrays.lbeta.(xs_lbeta_cu_tuple...)
lbeta_val_gpu = Array(lbeta_val_gpu)
for i = 1:n
@test lbeta_val_cpu[i] ≈ lbeta_val_gpu[i]
end
results[fn] = (lbeta_val_cpu, lbeta_val_gpu)
end

end

@testset "Gradient evaluation" begin
fn = :lgamma
@testset "$fn" begin
lgamma_grad_cpu = @time g∑fx(SpecialFunctions.lgamma, xs_lgamma)
lgamma_grad_gpu = @time g∑fx(CuArrays.lgamma, xs_lgamma_cu)
lgamma_grad_gpu = Array(lgamma_grad_gpu)
for i = 1:n
@test lgamma_grad_cpu[i] ≈ lgamma_grad_gpu[i]
end
end

fn = :digamma
@testset "$fn" begin
digamma_grad_cpu = @time g∑fx(SpecialFunctions.digamma, xs_digamma)
digamma_grad_gpu = @time g∑fx(CuArrays.digamma, xs_digamma_cu)
digamma_grad_gpu = Array(digamma_grad_gpu)
for i = 1:n
@test digamma_grad_cpu[i] ≈ digamma_grad_gpu[i]
end
end

fn = :lbeta
@testset "$fn" begin
lbeta_grad_cpu = @time g∑fx(SpecialFunctions.lbeta, xs_lbeta_tuple...)
lbeta_grad_gpu = @time g∑fx(CuArrays.lbeta, xs_lbeta_cu_tuple...)
lbeta_grad_gpu = Array(lbeta_grad_gpu)
for i = 1:n
@test lbeta_grad_cpu[i] ≈ lbeta_grad_gpu[i]
end
end
end