Skip to content

Commit

Permalink
Merge pull request #73 from SciML/trackedreal
Browse files Browse the repository at this point in the history
Branch for handling ReverseDiff TrackedReals in SDEs
  • Loading branch information
ChrisRackauckas authored Nov 21, 2020
2 parents 248ccdf + ff76d5b commit fb527d8
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function __init__()
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
wiener_randn!(rng::AbstractRNG,rand_vec::CuArrays.CuArray) = randn!(rand_vec)
end

@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
wiener_randn!(rng::AbstractRNG,rand_vec::CUDA.CuArray) = randn!(rand_vec)
end
Expand All @@ -13,5 +13,8 @@ function __init__()
@inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG,proto::ReverseDiff.TrackedArray)
ReverseDiff.track(convert.(eltype(proto.value),randn(rng,size(proto))))
end
@inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG,proto::AbstractArray{<:ReverseDiff.TrackedReal})
ReverseDiff.track.(randn.(rng,eltype(DiffEqBase.value.(proto))))
end
end
end

0 comments on commit fb527d8

Please sign in to comment.