Skip to content

Commit

Permalink
Merge pull request #483 from SciML/myb/fix
Browse files Browse the repository at this point in the history
Fix ODEFunction constructor
  • Loading branch information
ChrisRackauckas authored Aug 22, 2023
2 parents a6b09ce + 8d7de64 commit 89a2906
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 40 deletions.
22 changes: 11 additions & 11 deletions ext/ZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved
# https://github.com/SciML/RecursiveArrayTools.jl/blob/d06ecb856f43bc5e37cbaf50e5f63c578bf3f1bd/ext/RecursiveArrayToolsZygoteExt.jl#L67
@adjoint function getindex(VA::ODESolution, i::Int, j::Int)
function ODESolution_getindex_pullback(Δ)
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k = 1:length(VA.u[1])] :
zero(VA.u[1]) for m = 1:length(VA.u)]
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
zero(VA.u[1]) for m in 1:length(VA.u)]
dp = zero(VA.prob.p)
dprob = remake(VA.prob, p = dp)
du, dprob
T = eltype(eltype(VA.u))
N = length(VA.prob.p)
Δ′ = ODESolution{T,N,typeof(du),Nothing,Nothing,typeof(VA.t),
typeof(VA.k),typeof(dprob),typeof(VA.alg),typeof(VA.interp),
typeof(VA.destats),typeof(VA.alg_choice)}(du, nothing, nothing,
Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, typeof(VA.t),
typeof(VA.k), typeof(dprob), typeof(VA.alg), typeof(VA.interp),
typeof(VA.destats), typeof(VA.alg_choice)}(du, nothing, nothing,
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.destats,
VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
Expand All @@ -32,22 +32,22 @@ end
du, dprob = if i === nothing
getter = getobserved(VA)
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
du = [k == j ? grz[2] : zero(VA.u[1]) for k = 1:length(VA.u)]
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
dp = grz[3] # pullback for p
dprob = remake(VA.prob, p = dp)
du, dprob
else
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k = 1:length(VA.u[1])] :
zero(VA.u[1]) for m = 1:length(VA.u)]
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
zero(VA.u[1]) for m in 1:length(VA.u)]
dp = zero(VA.prob.p)
dprob = remake(VA.prob, p = dp)
du, dprob
end
T = eltype(eltype(VA.u))
N = length(VA.prob.p)
Δ′ = ODESolution{T,N,typeof(du),Nothing,Nothing,typeof(VA.t),
typeof(VA.k),typeof(dprob),typeof(VA.alg),typeof(VA.interp),
typeof(VA.destats),typeof(VA.alg_choice)}(du, nothing, nothing,
Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, typeof(VA.t),
typeof(VA.k), typeof(dprob), typeof(VA.alg), typeof(VA.interp),
typeof(VA.destats), typeof(VA.alg_choice)}(du, nothing, nothing,
VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.destats,
VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
Expand Down
2 changes: 1 addition & 1 deletion src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}}
ensemblealg::BasicEnsembleAlgorithm; kwargs...)
# TODO: @invoke
invoke(__solve, Tuple{AbstractEnsembleProblem, typeof(alg), typeof(ensemblealg)},
prob, alg, ensemblealg; trajectories=length(prob.prob), kwargs...)
prob, alg, ensemblealg; trajectories = length(prob.prob), kwargs...)
end

function __solve(prob::AbstractEnsembleProblem,
Expand Down
33 changes: 20 additions & 13 deletions src/ensemble/ensemble_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ DEFAULT_REDUCTION(u, data, I) = append!(u, data), false
DEFAULT_VECTOR_PROB_FUNC(prob, i, repeat) = prob[i]
function EnsembleProblem(prob::AbstractVector{<:AbstractSciMLProblem}; kwargs...)
# TODO: @invoke
invoke(EnsembleProblem, Tuple{Any}, prob; prob_func=DEFAULT_VECTOR_PROB_FUNC, kwargs...)
invoke(EnsembleProblem,
Tuple{Any},
prob;
prob_func = DEFAULT_VECTOR_PROB_FUNC,
kwargs...)
end
function EnsembleProblem(prob;
output_func = DEFAULT_OUTPUT_FUNC,
Expand All @@ -36,20 +40,23 @@ function EnsembleProblem(; prob,
EnsembleProblem(prob, prob_func, output_func, reduction, u_init, safetycopy)
end

struct WeightedEnsembleProblem{T1<:AbstractEnsembleProblem, T2<:AbstractVector} <: AbstractEnsembleProblem
ensembleprob::T1
weights::T2
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
AbstractEnsembleProblem
ensembleprob::T1
weights::T2
end
function Base.propertynames(e::WeightedEnsembleProblem)
(Base.propertynames(getfield(e, :ensembleprob))..., :weights)
end
Base.propertynames(e::WeightedEnsembleProblem) = (Base.propertynames(getfield(e, :ensembleprob))..., :weights)
function Base.getproperty(e::WeightedEnsembleProblem, f::Symbol)
f === :weights && return getfield(e, :weights)
f === :ensembleprob && return getfield(e, :ensembleprob)
return getproperty(getfield(e, :ensembleprob), f)
f === :weights && return getfield(e, :weights)
f === :ensembleprob && return getfield(e, :ensembleprob)
return getproperty(getfield(e, :ensembleprob), f)
end
function WeightedEnsembleProblem(args...; weights, kwargs...)
# TODO: allow skipping checks?
@assert sum(weights) 1
ep = EnsembleProblem(args...; kwargs...)
@assert length(ep.prob) == length(weights)
WeightedEnsembleProblem(ep, weights)
# TODO: allow skipping checks?
@assert sum(weights) 1
ep = EnsembleProblem(args...; kwargs...)
@assert length(ep.prob) == length(weights)
WeightedEnsembleProblem(ep, weights)
end
13 changes: 9 additions & 4 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function EnsembleSolution(sim::T, elapsedTime,
converged)
end

struct WeightedEnsembleSolution{T1<:AbstractEnsembleSolution, T2<:Number}
struct WeightedEnsembleSolution{T1 <: AbstractEnsembleSolution, T2 <: Number}
ensol::T1
weights::Vector{T2}
function WeightedEnsembleSolution(ensol, weights)
Expand Down Expand Up @@ -207,13 +207,18 @@ end
end
end


Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...)
return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...)
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution,
::Colon,
args::Colon...)
return invoke(getindex,
Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...},
x,
:,
args...)
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
Expand Down
11 changes: 7 additions & 4 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,8 @@ See the `modelingtoolkitize` function from
automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ, S,
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
S,
S2, S3, O, TCV,
SYS} <: AbstractODEFunction{iip}
f::F
Expand Down Expand Up @@ -2253,13 +2254,14 @@ function ODEFunction{iip, specialize}(f;
typeof(_colorvec),
typeof(sys)}(f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, paramjac, syms, indepsym, paramsyms,
Wfact_t, W_prototype, paramjac, syms, indepsym, paramsyms,
observed, _colorvec, sys)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), typeof(paramjac),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
typeof(paramjac),
typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed),
typeof(_colorvec),
typeof(sys)}(f, mass_matrix, analytic, tgrad, jac,
Expand All @@ -2270,7 +2272,8 @@ function ODEFunction{iip, specialize}(f;
ODEFunction{iip, specialize,
typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype), typeof(paramjac),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
typeof(paramjac),
typeof(syms), typeof(indepsym), typeof(paramsyms), typeof(observed),
typeof(_colorvec),
typeof(sys)}(f, mass_matrix, analytic, tgrad, jac,
Expand Down
10 changes: 5 additions & 5 deletions test/downstream/ensemble_multi_prob.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ using ModelingToolkit, OrdinaryDiffEq, Test
D = Differential(t)

@named sys1 = ODESystem([D(x) ~ x,
D(y) ~ -y])
D(y) ~ -y])
@named sys2 = ODESystem([D(x) ~ 2x,
D(y) ~ -2y])
D(y) ~ -2y])
@named sys3 = ODESystem([D(x) ~ 3x,
D(y) ~ -3y])
D(y) ~ -3y])

prob1 = ODEProblem(sys1, [1.0, 1.0], (0.0, 1.0))
prob2 = ODEProblem(sys2, [2.0, 2.0], (0.0, 1.0))
Expand All @@ -22,6 +22,6 @@ for i in 1:3
@test sol[y, :][i] == sol[i][y]
end
# Ensemble is a recursive array
@test only.(sol(0.0, idxs=[x])) == sol[1, 1, :] == first.(sol[x, :])
@test only.(sol(0.0, idxs = [x])) == sol[1, 1, :] == first.(sol[x, :])
# TODO: fix the interpolation
@test only.(sol(1.0, idxs=[x])) last.(sol[x, :])
@test only.(sol(1.0, idxs = [x])) last.(sol[x, :])
2 changes: 1 addition & 1 deletion test/downstream/remake_autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ if VERSION >= v"1.9"
_prob = remake(prob, u0 = u0, p = p)
soln = solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))
sum(soln[o, i] for i = 1:length(soln))
sum(soln[o, i] for i in 1:length(soln))
end

du01, dp1 = Zygote.gradient(symbolic_indexing_observed, u0, p)
Expand Down
2 changes: 1 addition & 1 deletion test/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ end
ode = ODEProblem(f, 1.0, (0.0, 1.0))
sol = SciMLBase.build_solution(ode, :NoAlgorithm, [ode.tspan[begin]], [ode.u0])
@test sol(0.0) == 1.0
@test sol([0.0,0.0]) == [1.0, 1.0]
@test sol([0.0, 0.0]) == [1.0, 1.0]
# test that indexing out of bounds doesn't segfault
@test_throws ErrorException sol(1)
@test_throws ErrorException sol(-0.5)
Expand Down

0 comments on commit 89a2906

Please sign in to comment.