Skip to content

Commit

Permalink
ehn: allow for batch solving for dwds
Browse files Browse the repository at this point in the history
  • Loading branch information
slibkind committed Nov 6, 2021
1 parent 86c02b8 commit 2d5f9ea
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 35 deletions.
52 changes: 29 additions & 23 deletions src/dwd_dynam.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ noutputs(interface::AbstractDirectedInterface) = length(output_ports(interface))

ndims(::DirectedVectorInterface{T, N}) where {T,N} = N

unit(::Type{I}, dims) where {T, I<:AbstractDirectedInterface{T}} = zeros(T, dims)
unit(::Type{DirectedVectorInterface{T,N}}, dims) where {T, N} = fill(zeros(T, N), dims)
unit(::Type{I}, shape) where {T, I<:AbstractDirectedInterface{T}} = fill(zero(T), shape)
unit(::Type{DirectedVectorInterface{T,N}}, shape) where {T, N} = fill(zeros(T, N), shape)


### Dynamics
Expand Down Expand Up @@ -197,20 +197,21 @@ Evaluates the dynamics of the machine `m` at state `u`, parameters `p`, and time
The length of `xs` must equal the number of inputs to `m`.
"""

eval_dynamics(f::DelayMachine, u, xs, h, p=nothing, t=0) = begin
ninputs(f) == length(xs) || error("$xs must have length $(ninputs(f)) to set the exogenous variables.")
#ninputs(f) == length(xs) || error("$xs must have length $(ninputs(f)) to set the exogenous variables.")
dynamics(f)(collect(u), collect(xs), h, p, t)
end

eval_dynamics(f::AbstractMachine, u, xs, p=nothing, t=0) = begin
ninputs(f) == length(xs) || error("$xs must have length $(ninputs(f)) to set the exogenous variables.")
#ninputs(f) == length(xs) || error("$xs must have length $(ninputs(f)) to set the exogenous variables.")
dynamics(f)(collect(u), collect(xs), p, t)
end

# eval_dynamics(f::AbstractMachine, u::S, xs::T, args...) where {S,T <: Union{FinDomFunction, AbstractVector}} =
# eval_dynamics(f, collect(u), collect(xs), args...)

eval_dynamics(f::AbstractMachine, u::AbstractVector, xs::AbstractVector{T}, p=nothing, t=0) where T <: Function =
eval_dynamics(f::AbstractMachine, u, xs::AbstractVector{T}, p=nothing, t=0) where T <: Function =
eval_dynamics(f, u, [x(t) for x in xs], p, t)

""" euler_approx(m::ContinuousMachine, h)
Expand Down Expand Up @@ -245,13 +246,13 @@ euler_approx(fs::AbstractDict{S, M}, args...) where {S, M<:ContinuousMachine} =
Constructs an ODEProblem from the vector field defined by `(u,p,t) -> m.dynamics(u, x, p, t)`. The exogenous variables are determined by `xs`.
"""
ODEProblem(m::ContinuousMachine{T}, u0, xs::AbstractVector, tspan, p=nothing; kwargs...) where T=
ODEProblem(m::ContinuousMachine{T}, u0, xs, tspan, p=nothing; kwargs...) where T=
ODEProblem((u,p,t) -> eval_dynamics(m, u, xs, p, t), u0, tspan, p; kwargs...)

ODEProblem(m::ContinuousMachine{T}, u0, x::Union{T, Function}, tspan, p=nothing; kwargs...) where T=
ODEProblem(m, u0, collect(repeated(x, ninputs(m))), tspan, p; kwargs...)

ODEProblem(m::ContinuousMachine{T}, u0, tspan, p=nothing; kwargs...) where T =
ODEProblem(m::ContinuousMachine{T}, u0, tspan::Tuple, p=nothing; kwargs...) where T =
ODEProblem(m, u0, T[], tspan, p; kwargs...)

""" DDEProblem(m::DelayMachine, u0::Vector, xs::Vector, h::Function, tspan, p = nothing; kwargs...)
Expand Down Expand Up @@ -346,26 +347,27 @@ end

function induced_dynamics(d::WiringDiagram, ms::Vector{M}, S, Inputs) where {T,I, M<:AbstractMachine{T,I}}

function v(u::AbstractVector, xs::AbstractVector, p, t::Real)
function v(u, xs, p, t)
states = destruct(S, u) # a list of the states by box
readouts = get_readouts(ms, states, p, t)
readins = unit(I, length(apex(Inputs)))
readin_shape = u isa AbstractVector ? length(apex(Inputs)) : (length(apex(Inputs)), size(u,2))
readins = unit(I, readin_shape)
fill_readins!(readins, d, Inputs, readouts, xs)

reduce(vcat, map(enumerate(destruct(Inputs, readins))) do (i,x)
eval_dynamics(ms[i], states[i], x, p, t)
end)
end

end

function induced_dynamics(d::WiringDiagram, ms::Vector{M}, S, Inputs) where {T,I, M<:DelayMachine{T,I}}

function v(u::AbstractVector, xs::AbstractVector, h, p, t::Real)
function v(u, xs, h, p, t)
states = destruct(S, u) # a list of the states by box
hists = destruct(S, h)
readouts = get_readouts(ms, states, hists, p, t)
readins = unit(I, length(apex(Inputs)))
readin_shape = u isa AbstractVector ? length(apex(Inputs)) : (length(apex(Inputs)), size(u,2))
readins = unit(I, readin_shape)
fill_readins!(readins, d, Inputs, readouts, xs)

reduce(vcat, map(enumerate(destruct(Inputs, readins))) do (i,x)
Expand All @@ -375,20 +377,22 @@ function induced_dynamics(d::WiringDiagram, ms::Vector{M}, S, Inputs) where {T,I
end

function induced_readout(d::WiringDiagram, ms::Vector{M}, S) where {T, I, M<:AbstractMachine{T,I}}
function r(u::AbstractVector, p, t)
function r(u, p, t)
states = destruct(S, u)
readouts = get_readouts(ms, states, p, t)
outputs = unit(I, length(output_ports(d)))
readout_shape = u isa AbstractVector ? length(output_ports(d)) : (length(output_ports(d)), size(u,2))
outputs = unit(I, readout_shape)
fill_outputs!(outputs, d, readouts)
end
end

function induced_readout(d::WiringDiagram, ms::Vector{M}, S) where {T, I, M<:DelayMachine{T,I}}
function r(u::AbstractVector, h, p, t)
function r(u, h, p, t)
states = destruct(S, u)
hists = destruct(S, h)
readouts = get_readouts(ms, states, hists, p, t)
outputs = unit(I, length(output_ports(d)))
readout_shape = u isa AbstractVector ? length(output_ports(d)) : (length(output_ports(d)), size(u,2))
outputs = unit(I, readout_shape)
fill_outputs!(outputs, d, readouts)
end
end
Expand All @@ -405,10 +409,13 @@ function fills(m::AbstractMachine, d::WiringDiagram, b::Int)
end


destruct(C::Colimit, xs::FinDomFunction) = map(1:length(C)) do i
collect(compose(legs(C)[i], xs))
destruct(C::Colimit, xs::AbstractVector) = map(1:length(C)) do i
xs[legs(C)[i].func]
end

destruct(C::Colimit, xs::AbstractMatrix) = map(1:length(C)) do i
xs[legs(C)[i].func, :]
end
destruct(C::Colimit, xs::AbstractVector) = destruct(C, FinDomFunction(xs))

destruct(C::Colimit, h) = map(1:length(C)) do i
(p,t) -> destruct(C, h(p,t))[i]
Expand All @@ -424,12 +431,11 @@ end


function fill_readins!(readins, d::WiringDiagram, Inputs::Colimit, readouts, xs)

for w in wires(d, :Wire)
readins[legs(Inputs)[w.target.box](w.target.port)] += readouts[w.source.box][w.source.port]
readins[legs(Inputs)[w.target.box](w.target.port), :] += readouts[w.source.box][w.source.port, :]
end
for w in wires(d, :InWire)
readins[legs(Inputs)[w.target.box](w.target.port)] += xs[w.source.port]
readins[legs(Inputs)[w.target.box](w.target.port), :] += xs[w.source.port, :]
end

return readins
Expand All @@ -439,7 +445,7 @@ end
function fill_outputs!(outs, d::WiringDiagram, readouts)

for w in wires(d, :OutWire)
outs[w.target.port] += readouts[w.source.box][w.source.port]
outs[w.target.port, :] += readouts[w.source.box][w.source.port, :]
end

return outs
Expand Down
9 changes: 7 additions & 2 deletions test/dwd_dynam.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ add_wires!(d_big, Pair[

@testset "ODE Problems" begin
# Identity
uf(u, x, p, t) = [x[1] - u[1]]
uf(u, x, p, t) = x - u
rf(u, args...) = u
mf = ContinuousMachine{Float64}(1,1,1, uf, rf)

Expand All @@ -53,7 +53,7 @@ add_wires!(d_big, Pair[
@test eval_dynamics(m_id, [x0], [p0]) == [p0 - x0]
@test readout(m_id, [x0]) == [x0]

# unfed parameter
# compose
m12 = oapply(d12, [mf, mf])

x0 = -1
Expand All @@ -62,6 +62,11 @@ add_wires!(d_big, Pair[
@test eval_dynamics(m12, [x0, y0], [p0]) == [p0 - x0, x0 - y0]
@test readout(m12, [x0,y0]) == [y0]

# test batch
batch_size = 10
us = reshape(1:(2*batch_size), 2, batch_size)
xs = reshape(1:batch_size, 1, batch_size)
@test eval_dynamics(m12, us, xs) == (hcat(collect(1:batch_size) .- collect(us[1,:]), collect(us[1,:]) - collect(us[2,:])) |> transpose)

# break and back together
m = oapply(d_copymerge, Dict(:f => mf))
Expand Down
20 changes: 10 additions & 10 deletions test/trajectories.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ approx_equal(u, v) = abs(maximum(u - v)) < .1
dx(x) = [1 - x[1]^2, 2*x[1]-x[2]]
dy(y) = [1 - y[1]^2]

r = ContinuousResourceSharer{Real}(2, (u,p,t) -> dx(u))
r = ContinuousResourceSharer{Float64}(2, (u,p,t) -> dx(u))

u0 = [-1.0, -2.0]
tspan = (0.0, 100.0)
Expand All @@ -43,7 +43,7 @@ t = solve(dds, FunctionMap())



dr = DiscreteResourceSharer{Real}(1, (u,p,t) -> dy(u))
dr = DiscreteResourceSharer{Float64}(1, (u,p,t) -> dy(u))
u0 = [1.0]
dds = DiscreteProblem(dr, u0, tspan, nothing)
t = solve(dds, FunctionMap())
Expand Down Expand Up @@ -109,9 +109,9 @@ dotr(u,p,t) = p[1]*u
dotrf(u,p,t) = [-p[2]*u[1]*u[2], p[3]*u[1]*u[2]]
dotf(u,p,t) = -p[4]*u

r = ContinuousResourceSharer{Real}(1, dotr)
rf_pred = ContinuousResourceSharer{Real}(2, dotrf)
f = ContinuousResourceSharer{Real}(1, dotf)
r = ContinuousResourceSharer{Float64}(1, dotr)
rf_pred = ContinuousResourceSharer{Float64}(2, dotrf)
f = ContinuousResourceSharer{Float64}(1, dotf)


rf_pattern = UWD(0)
Expand Down Expand Up @@ -142,8 +142,8 @@ end)
dotr(u, x, p, t) = [p[1]*u[1] - p[2]*u[1]*x[1]]
dotf(u, x, p, t) = [p[3]*u[1]*x[1] - p[4]*u[1]]

rmachine = ContinuousMachine{Real}(1,1,1, dotr, (r,p,t) -> r)
fmachine = ContinuousMachine{Real}(1,1,1, dotf, (f,p,t) -> f)
rmachine = ContinuousMachine{Float64}(1,1,1, dotr, (r,p,t) -> r)
fmachine = ContinuousMachine{Float64}(1,1,1, dotf, (f,p,t) -> f)

rf_pattern = WiringDiagram([],[])
boxr = add_box!(rf_pattern, Box(nothing, [nothing], [nothing]))
Expand Down Expand Up @@ -180,9 +180,9 @@ dotfish(f, x, p, t) = [p[1]*f[1] - p[2]*x[1]*f[1]]
dotFISH(F, x, p, t) = [p[3]*x[1]*F[1] - p[4]*F[1] - p[5]*x[2]*F[1]]
dotsharks(s, x, p, t) = [-p[7]*s[1] + p[6]*s[1]*x[1]]

fish = ContinuousMachine{Real}(1,1,1, dotfish, (f,p,t) ->f)
FISH = ContinuousMachine{Real}(2,1,2, dotFISH, (F,p,t)->[F[1], F[1]])
sharks = ContinuousMachine{Real}(1,1,1, dotsharks, (s,p,t)->s)
fish = ContinuousMachine{Float64}(1,1,1, dotfish, (f,p,t) ->f)
FISH = ContinuousMachine{Float64}(2,1,2, dotFISH, (F,p,t)->[F[1], F[1]])
sharks = ContinuousMachine{Float64}(1,1,1, dotsharks, (s,p,t)->s)

ocean_pat = WiringDiagram([], [])
boxf = add_box!(ocean_pat, Box(nothing, [nothing], [nothing]))
Expand Down

0 comments on commit 2d5f9ea

Please sign in to comment.