From 9c3098ebdceb9e75414518258fed923307a959d5 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 21 Nov 2020 02:00:12 -0500 Subject: [PATCH 1/2] Branch for handling ReverseDiff TrackedReals in SDEs --- src/init.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/init.jl b/src/init.jl index ea41009..722c6e1 100644 --- a/src/init.jl +++ b/src/init.jl @@ -13,5 +13,8 @@ function __init__() @inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG,proto::ReverseDiff.TrackedArray) ReverseDiff.track(convert.(eltype(proto.value),randn(rng,size(proto)))) end + @inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG,proto::AbstractArray{<:TrackedReal}) + ReverseDiff.track.(randn.(rng,eltype(DiffEqBase.value.(proto)))) + end end end From ff76d5b8da7b0821ec4470abb1e287cd57f43300 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 21 Nov 2020 03:11:26 -0500 Subject: [PATCH 2/2] fix namespacing --- src/init.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/init.jl b/src/init.jl index 722c6e1..2674d9f 100644 --- a/src/init.jl +++ b/src/init.jl @@ -4,7 +4,7 @@ function __init__() @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin wiener_randn!(rng::AbstractRNG,rand_vec::CuArrays.CuArray) = randn!(rand_vec) end - + @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin wiener_randn!(rng::AbstractRNG,rand_vec::CUDA.CuArray) = randn!(rand_vec) end @@ -13,7 +13,7 @@ function __init__() @inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG,proto::ReverseDiff.TrackedArray) ReverseDiff.track(convert.(eltype(proto.value),randn(rng,size(proto)))) end - @inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG,proto::AbstractArray{<:TrackedReal}) + @inline function DiffEqNoiseProcess.wiener_randn(rng::Random.AbstractRNG,proto::AbstractArray{<:ReverseDiff.TrackedReal}) ReverseDiff.track.(randn.(rng,eltype(DiffEqBase.value.(proto)))) end end