Skip to content

Commit

Permalink
add manual example
Browse files Browse the repository at this point in the history
  • Loading branch information
matthieugomez committed Jun 13, 2024
1 parent a9caa90 commit 7c1ad13
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 12 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ julia = "1.2"

[extras]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
InfinitesimalGenerators = "2fce0c6f-5f0b-5c85-85c9-2ffe1d5ee30d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Distributions"]
test = ["Test", "Distributions", "InfinitesimalGenerators"]
58 changes: 50 additions & 8 deletions examples/ConsumptionProblem/WangWangYang.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,80 @@ Base.@kwdef mutable struct WangWangYangModel
ρ::Float64 = 0.04
γ::Float64 = 3.0
ψ::Float64 = 1.1
wmax::Float64 = 5000.0
wmin::Float64 = 0.0
wmax::Float64 = 1000.0
end


function (m::WangWangYangModel)(state::NamedTuple, y::NamedTuple)
(; μ, σ, r, ρ, γ, ψ, wmax) = m
(; μ, σ, r, ρ, γ, ψ, wmin, wmax) = m
(; w) = state
(; p, pw_up, pw_down, pww) = y
pw = pw_up
iter = 0
@label start
pw = max(pw, sqrt(eps()))
c = (r + ψ *- r)) * p * pw^(-ψ)
μw = (r - μ + σ^2) * w + 1 - c
if (iter == 0) & (μw <= 0)
iter += 1
pw = pw_down
@goto start
end

# One only needs a ghost node if μw <= 0 (since w^2p_ww = 0). In this case, we obtain a formula for pw so that c <= 1
if w 0.0 && μw <= 0.0
pw = ((r + ψ *- r)) * p)^(1 / ψ)
c = 1.0
μw = 0.0
if w wmin && μw <= 0.0
μw = 0.0
c = 1.0
pw = (c / ((r + ψ *- r))))^(-1 / ψ)
end
# At the top, I use the solution of the unconstrainted, i.e. pw = 1 (I could also do reflecting boundary but less elegant)
pt = - ((((r + ψ *- r)) * pw^(1 - ψ) - ψ * ρ) /- 1) + μ - γ * σ^2 / 2) * p + ((r - μ + γ * σ^2) * w + 1) * pw + σ^2 * w^2 / 2 * (pww - γ * pw^2 / p))
return (; pt)
end

m = WangWangYangModel()
stategrid = OrderedDict(:w => range(0.0, m.wmax, length = 100))
stategrid = OrderedDict(:w => range(m.wmin^(1/2), m.wmax^(1/2), length = 100).^2)
yend = OrderedDict(:p => 1 .+ stategrid[:w])
result = pdesolve(m, stategrid, yend, bc = OrderedDict(:pw => (1.0, 1.0)))
@assert result.residual_norm <= 1e-5





# Alternative solution bypassing pdesolve
# just encode the PDE has a vector equation
using InfinitesimalGenerators
function solve!(pts, m, ws, ps)
(; μ, σ, r, ρ, γ, ψ, wmin, wmax) = m
pw_ups = FirstDerivative(ws, ps; direction = :upward, bc = (0.0, 1.0))
pw_downs = FirstDerivative(ws, ps; direction = :downward, bc = (0.0, 1.0))
pwws = SecondDerivative(ws, ps, bc = (0.0, 1.0))
for i in eachindex(ws)
w = ws[i]
p, pw_up, pw_down, pww = ps[i], pw_ups[i], pw_downs[i], pwws[i]
pw = pw_up
iter = 0
@label start
pw = max(pw, sqrt(eps()))
c = (r + ψ *- r)) * p * pw^(-ψ)
μw = (r - μ + σ^2) * w + 1 - c
if (iter == 0) & (μw <= 0)
iter += 1
pw = pw_down
@goto start
end
# One only needs a ghost node if μw <= 0 (since w^2p_ww = 0). In this case, we obtain a formula for pw so that c <= 1
if w wmin && μw <= 0.0
μw = 0.0
c = 1.0
pw = (c / ((r + ψ *- r))))^(-1 / ψ)
end
pts[i] = - ((((r + ψ *- r)) * pw^(1 - ψ) - ψ * ρ) /- 1) + μ - γ * σ^2 / 2) * p + ((r - μ + γ * σ^2) * w + 1) * pw + σ^2 * w^2 / 2 * (pww - γ * pw^2 / p))
end
return pts
end
m = WangWangYangModel()
ws = range(m.wmin^(1/2), m.wmax^(1/2), length = 100).^2
ps = 1 .+ stategrid[:w]
finiteschemesolve((ydot, y) -> solve!(ydot, m, ws, y), ps)
7 changes: 6 additions & 1 deletion src/finiteschemesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ function finiteschemesolve(G!, y0; Δ = 1.0, is_algebraic = fill(false, size(y0)
G!(ydot, ypost)
residual_norm = norm(ydot) / length(ydot)
isnan(residual_norm) && throw("G! returns NaN with the initial value")
if residual_norm <= maxdist
verbose && @warn "G! already returns zero with the initial value"
return ypost, residual_norm
end
if Δ == Inf
ypost, residual_norm = implicit_timestep(G!, y0, Δ; is_algebraic = is_algebraic, verbose = verbose, iterations = iterations, method = method, autodiff = autodiff, maxdist = maxdist, J0c = J0c, y̲ = y̲, ȳ = ȳ)
else
Expand Down Expand Up @@ -85,7 +89,8 @@ function finiteschemesolve(G!, y0; Δ = 1.0, is_algebraic = fill(false, size(y0)
end
end
end
verbose && ((residual_norm > maxdist) |< minΔ)) && @warn "Iteration did not converge"
verbose && (iter >= iterations) && @warn "Algorithm did not converge: Iter higher than the limit $(iterations)"
verbose &&< minΔ) && @warn "Algorithm did not converge: TimeStep lower than the limit $(minΔ)"
return ypost, residual_norm
end

Expand Down
72 changes: 72 additions & 0 deletions src/pdesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,75 @@ function _setindex!(@nospecialize(a), apm, stategrid::StateGrid, Tsolution, y_M:
end





function pdesolve2(apm, @nospecialize(grid), @nospecialize(yend); is_algebraic = OrderedDict(k => false for k in keys(yend)), bc = nothing, verbose = true, kwargs...)
stategrid = StateGrid(NamedTuple(grid))
S = size(stategrid)
all(size(v) == S for v in values(yend)) || throw(ArgumentError("The length of initial guess (e.g. terminal value) does not equal the length of the state space"))
all(keys(is_algebraic) .== keys(yend)) || throw(ArgumentError("the terminal guess yend and the is_algebric keyword argument must have the same names"))
is_algebraic = OrderedDict(first(p) => fill(last(p), S) for p in pairs(is_algebraic))
y = OrderedDict(first(p) => collect(last(p)) for p in pairs(yend))
# convert to Matrix
yend_M = catlast(values(yend))
is_algebraic_M = catlast(values(is_algebraic))
bc_M = _Array_bc(bc, yend, grid)
Tsolution = Type{tuple(keys(yend)...)}

a = get_a2(apm, stategrid, Tsolution, yend_M, bc_M)
y_M, residual_norm = finiteschemesolve((ydot, y) -> hjb2!(apm, stategrid, Tsolution, ydot, y, bc_M, size(yend_M)), vec(yend_M); is_algebraic = vec(is_algebraic_M), J0c = J0c, verbose = verbose, kwargs... )
y_M = reshape(y_M, size(yend_M)...)
_setindex!(y, y_M)
if a !== nothing
_setindex2!(a, apm, stategrid, Tsolution, y_M, bc_M)
a = merge(y, a)
end
return EconPDEResult(y, residual_norm, a)
end

function get_a2(apm, stategrid::StateGrid, Tsolution, y_M::AbstractArray, bc_M::AbstractArray)
derivatives = differentiate2(Tsolution, stategrid, y_M, bc_M)
result = apm(stategrid, derivatives)
if length(result) == 1
return nothing
else
return OrderedDict(a_key => Array{Float64}(undef, size(stategrid)) for a_key in keys(result[2]))
end
end

# create hjb! that accepts and returns AbstractVector rather than AbstractArrays
function hjb2!(apm, stategrid::StateGrid, Tsolution, ydot::AbstractVector, y::AbstractVector, bc_M::AbstractArray, ysize::NTuple)
y_M = reshape(y, ysize...)
ydot_M = reshape(ydot, ysize...)
vec(hjb!(apm, stategrid, Tsolution, ydot_M, y_M, bc_M))
end

function hjb2!(apm, stategrid::StateGrid, Tsolution, ydot_M::AbstractArray, y_M::AbstractArray, bc_M::AbstractArray)
solution = differentiate2(Tsolution, stategrid, y_M, bc_M)
out = apm(stategrid, solution)
if isa(outi[1], Vector)
_setindex2!(ydot_M, Tsolution, outi, i)
else
_setindex2!(ydot_M, Tsolution, outi[1], i)
end
return ydot_M
end


@generated function _setindex2!(ydot_M::AbstractArray, ::Type{Tsolution}, outi::NamedTuple) where {Tsolution}
N = length(Tsolution.parameters[1])
quote
$(Expr(:meta, :inline))
$(Expr(:block, [Expr(:call, :setindex!, :ydot_M, Expr(:call, :getproperty, :outi, Meta.quot(Symbol(Tsolution.parameters[1][k], :t))), :Colon(), k) for k in 1:N]...))
end
end


function _setindex2!(@nospecialize(a), apm, stategrid::StateGrid, Tsolution, y_M::AbstractArray, bc_M::AbstractArray)
solution = differentiate2(Tsolution, stategrid, y_M, bc_M)
out = apm(stategrid, solution)
for (k, v) in zip(values(a), values(outi))
copyto!(k, v)
end
end
35 changes: 35 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,38 @@ end
end
end



# 1 state variable
@generated function differentiate2(::Type{Tsolution}, grid::StateGrid{T1, 1, <: NamedTuple{N}}, y::AbstractArray{T}, bc) where {Tsolution, T1, N, T}
statename = N[1]
expr = Expr[]
for k in 1:length(Tsolution.parameters[1])
solname = Tsolution.parameters[1][k]
push!(expr, Expr(:(=), solname, :(v[$k])))
push!(expr, Expr(:(=), Symbol(solname, statename, :_up), :(va_up[$k])))
push!(expr, Expr(:(=), Symbol(solname, statename, :_down), :(va_down[$k])))
push!(expr, Expr(:(=), Symbol(solname, statename, statename), :(vaa[$k])))
end
quote
$(Expr(:meta, :inline))
K = length(Tsolution.parameters[1])
v = [zeros(length(grid.x)) for k in 1:K]
va_up = [zeros(length(grid.x)) for k in 1:K]
va_up = [zeros(length(grid.x)) for k in 1:K]
vaa = [zeros(length(grid.x)) for k in 1:K]
for i in 1:length(grid.x)
grida = grid.x[1]
Δxm = grida[max(i, 2)] - grida[max(i-1, 1)]
Δxp = grida[min(i+1, size(y, 1))] - grida[min(i, size(y, 1) - 1)]
Δx = (Δxm + Δxp) / 2
for k in 1:K
v[k][i] = y[i, k]
va_up[k][i] = (i < size(y, 1)) ? (y[i+1, k] - y[i, k]) / Δxp : convert($T, bc[end, k])
va_down[k][i] = ((i > 1) ? (y[i, k] - y[i-1, k]) / Δxm : convert($T, bc[1, k]))
vaa[k][i] = ((1 < i < size(y, 1)) ? (y[i + 1, k] / (Δxp * Δx) + y[i - 1, k] / (Δxm * Δx) - 2 * y[i, k] / (Δxp * Δxm)) : ((i == 1) ? (y[2, k] / (Δxp * Δx) + (y[1, k] - bc[1, k] * Δxm) / (Δxm * Δx) - 2 * y[1, k] / (Δxp * Δxm)) : ((y[end, k] + bc[end, k] * Δxp) / (Δxp * Δx) + y[end - 1, k] / (Δxm * Δx) - 2 * y[end, k] / (Δxp * Δxm))))
end
end
@inbounds $(Expr(:tuple, expr...))
end
end
2 changes: 0 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,3 @@ end
for x in (:Leland, )
@testset "$x" begin include("../examples/OptimalStoppingTime/$(x).jl") end
end


0 comments on commit 7c1ad13

Please sign in to comment.