Skip to content

Commit

Permalink
complete test
Browse files Browse the repository at this point in the history
  • Loading branch information
chooron committed Dec 11, 2024
1 parent df765c6 commit 35ff1bb
Show file tree
Hide file tree
Showing 17 changed files with 75 additions and 441 deletions.
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HydroModels"
uuid = "7e3cde01-c141-467b-bff6-5350a0af19fc"
authors = ["jingx <[email protected]>"]
version = "0.1.0"
version = "0.1.1"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down Expand Up @@ -47,12 +47,16 @@ Symbolics = "6"
TOML = "1"
Test = "1"
julia = "1.10"
CSV = "0.10"
DataFrames = "1"
Statistics = "1"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Aqua", "CSV", "DataFrames"]
test = ["Test", "Aqua", "CSV", "DataFrames", "Statistics"]
4 changes: 3 additions & 1 deletion src/HydroModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ include("utils/name.jl")
include("utils/show.jl")
include("utils/build.jl")
include("utils/sort.jl")
include("utils/check.jl")
include("utils/io.jl")
inclue("utils/check.jl")
export NamedTupleIOAdapter
include("utils/solver.jl")
export ManualSolver

# framework build
include("flux.jl")
Expand Down
6 changes: 3 additions & 3 deletions src/bucket.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ function (ele::HydroBucket{F,D,FF,OF,M})(
params_vec, nn_params_vec = param_func(pas), nn_param_func(pas)
flux_output = ele.flux_func.(eachslice(input, dims=2), Ref(params_vec), Ref(nn_params_vec), timeidx)
#* convert vector{vector} to matrix
flux_output_matrix = reduce(hcat, flux_output)
flux_output_matrix
flux_output_mat = reduce(hcat, flux_output)
flux_output_mat
end

function (ele::HydroBucket{F,D,FF,OF,M})(
Expand All @@ -241,7 +241,7 @@ function (ele::HydroBucket{F,D,FF,OF,M})(
check_ptypes(ele, input, ptypes)
check_stypes(ele, input, stypes)
#* check initial states
check_initstates(ele, pas)
check_initstates(ele, pas, stypes)
#* prepare initial states
init_states_mat = reduce(hcat, [collect(pas[:initstates][stype][get_state_names(ele)]) for stype in stypes])
#* extract params and nn params
Expand Down
6 changes: 3 additions & 3 deletions src/route.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ function (route::HydroRoute{F,PF,M})(
sol_arr = solver(du_func, pas, init_states_mat, timeidx, convert_to_array=true)
sol_arr_permuted = permutedims(sol_arr, (2, 1, 3))
cat_arr = cat(input, sol_arr_permuted, dims=1)
output_vec = [route.rfunc.func.(eachslice(cat_arr_, dims=2), param_func(pas), timeidx[i]) for cat_arr_ in eachslice(cat_arr, dims=3)]
output_vec = [route.rfunc.func.(eachslice(cat_arr[:, :, i], dims=2), param_func(pas), timeidx[i]) for i in axes(cat_arr, 3)]
out_arr = reduce(hcat, reduce.(vcat, output_vec))
#* return route_states and q_out
return cat(sol_arr_permuted, reshape(out_arr, 1, size(out_arr)...), dims=1)
Expand Down Expand Up @@ -345,8 +345,8 @@ function (route::RapidRoute)(
itp_funcs = interp.(eachslice(input[1, :, :], dims=1), Ref(timeidx), extrapolate=true)

#* prepare the parameters for the routing function
k_ps = [pas[:params][ptype][:k] for ptype in ptypes]
x_ps = [pas[:params][ptype][:x] for ptype in ptypes]
k_ps = [pas[:params][ptype][:rapid_k] for ptype in ptypes]
x_ps = [pas[:params][ptype][:rapid_x] for ptype in ptypes]
c0 = @. ((delta_t / k_ps) - (2 * x_ps)) / ((2 * (1 - x_ps)) + (delta_t / k_ps))
c1 = @. ((delta_t / k_ps) + (2 * x_ps)) / ((2 * (1 - x_ps)) + (delta_t / k_ps))
c2 = @. ((2 * (1 - x_ps)) - (delta_t / k_ps)) / ((2 * (1 - x_ps)) + (delta_t / k_ps))
Expand Down
13 changes: 6 additions & 7 deletions src/uh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,21 @@ Apply the unit hydrograph flux model to input data of various dimensions.

(::UnitHydrograph)(::AbstractVector, ::ComponentVector; kwargs...) = @error "UnitHydrograph is not support for single timepoint"

function (flux::UnitHydrograph{<:Any,<:Any,<:Any,:DISCRETE})(input::AbstractArray{T,2}, pas::ComponentVector; kwargs...) where {T}
function (flux::UnitHydrograph{<:Any,<:Any,<:Any,:DISCRETE})(input::AbstractArray{T,2}, pas::ComponentVector; config::NamedTuple=NamedTuple(), kwargs...) where {T}
solver = get(config, :solver, ManualSolver{true}())
timeidx = get(kwargs, :timeidx, collect(1:size(input, 2)))
timeidx = get(config, :timeidx, collect(1:size(input, 2)))
input_vec = input[1, :]
#* convert the lagflux to a discrete problem
lag_du_func(u,p,t) = input_vec[Int(t)] .* p[:weight] .+ [diff(u); -u[end]]
lag_du_func(u, p, t) = input_vec[Int(t)] .* p[:weight] .+ [diff(u); -u[end]]
#* prepare the initial states
lag = pas[:params][get_param_names(flux)[1]]
uh_weight = map(t -> flux.uhfunc(t, lag), 1:get_uh_tmax(flux.uhfunc, lag))[1:end-1]
if length(uh_weight) == 0
@warn "The unit hydrograph weight is empty, please check the unit hydrograph function"
return input
else
initstates = input_vec[1] .* uh_weight ./ sum(uh_weight)
#* solve the problem
sol = solver(lag_du_func, ComponentVector(weight=uh_weight ./ sum(uh_weight)), initstates, timeidx)
sol = solver(lag_du_func, ComponentVector(weight=uh_weight ./ sum(uh_weight)), zeros(length(uh_weight)), timeidx)
reshape(sol[1, :], 1, length(input_vec))
end
end
Expand All @@ -191,9 +190,9 @@ function (flux::UnitHydrograph{<:Any,<:Any,<:Any,:SPARSE})(input::AbstractArray{
end

# todo: 卷积计算的结果与前两个计算结果不太一致
function (flux::UnitHydrograph{<:Any,<:Any,<:Any,:INTEGRAL})(input::AbstractArray{T,2}, pas::ComponentVector; kwargs...) where {T}
function (flux::UnitHydrograph{<:Any,<:Any,<:Any,:INTEGRAL})(input::AbstractArray{T,2}, pas::ComponentVector; config::NamedTuple=NamedTuple(), kwargs...) where {T}
input_vec = input[1, :]
itp_method = get(kwargs, :interp, LinearInterpolation)
itp_method = get(config, :interp, LinearInterpolation)
itp = itp_method(input_vec, collect(1:length(input_vec)), extrapolate=true)
#* construct the unit hydrograph function based on the interpolation method and parameter
lag = pas[:params][get_param_names(flux)[1]]
Expand Down
23 changes: 13 additions & 10 deletions src/utils/check.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ end

function check_pas(component::AbstractComponent, pas::ComponentVector)
check_parameters(component, pas)
check_states(component, pas)
check_initstates(component, pas)
check_nns(component, pas)
end

function check_pas(component::AbstractComponent, pas::ComponentVector, ptypes::AbstractVector{Symbol}, stypes::AbstractVector{Symbol})
check_parameters(component, pas, ptypes)
check_states(component, pas, stypes)
check_initstates(component, pas, stypes)
check_nns(component, pas)
end

Expand All @@ -42,31 +42,34 @@ function check_parameters(component::AbstractComponent, pas::ComponentVector, pt
param_names = get_param_names(component)
cpt_name = get_name(component)
for ptype in ptypes
tmp_ptype_params_keys = keys(pas[:params][ptype])
for param_name in param_names
@assert(param_name in keys(pas[ptype][:params]),
"Parameter '$(param_name)' in component '$(cpt_name)' is required but not found in parameter type '$(ptype)'. Available parameters: $(keys(pas[ptype][:params]))"
@assert(param_name in tmp_ptype_params_keys,
"Parameter '$(param_name)' in component '$(cpt_name)' is required but not found in parameter type '$(ptype)'. Available parameters: $(tmp_ptype_params_keys)"
)
end
end
end

function check_states(component::AbstractComponent, pas::ComponentVector)
function check_initstates(component::AbstractComponent, pas::ComponentVector)
state_names = get_state_names(component)
cpt_name = get_name(component)
for state_name in state_names
@assert(state_name in keys(pas[:initstates]),
"Initial state '$(state_name)' in component '$(cpt_name)' is required but not found in pas[:initstates]. Available states: $(keys(pas[:initstates]))"
tmp_ptype_initstates_keys = keys(pas[:initstates])
@assert(state_name in tmp_ptype_initstates_keys,
"Initial state '$(state_name)' in component '$(cpt_name)' is required but not found in parameter type '$(ptype)'. Available states: $(tmp_ptype_initstates_keys)"
)
end
end

function check_states(component::AbstractComponent, pas::ComponentVector, stypes::AbstractVector{Symbol})
function check_initstates(component::AbstractComponent, pas::ComponentVector, stypes::AbstractVector{Symbol})
state_names = get_state_names(component)
cpt_name = get_name(component)
for stype in stypes
tmp_ptype_initstates_keys = keys(pas[:initstates][stype])
for state_name in state_names
@assert(state_name in keys(pas[stype][:initstates]),
"Initial state '$(state_name)' in component '$(cpt_name)' is required but not found in state type '$(stype)'. Available states: $(keys(pas[stype][:initstates]))"
@assert(state_name in tmp_ptype_initstates_keys,
"Initial state '$(state_name)' in component '$(cpt_name)' is required but not found in state type '$(stype)'. Available states: $(tmp_ptype_initstates_keys)"
)
end
end
Expand Down
6 changes: 3 additions & 3 deletions src/utils/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ function Base.show(io::IO, uh::AbstractHydrograph)
print(io, "inputs: ", isempty(uh.meta.inputs) ? "nothing" : join(uh.meta.inputs, ", "))
print(io, ", outputs: ", isempty(uh.meta.outputs) ? "nothing" : join(uh.meta.outputs, ", "))
print(io, ", params: ", isempty(uh.meta.params) ? "nothing" : join(uh.meta.params, ", "))
print(io, ", uhfunc: ", nameof(typeof(uh.uhfunc).parameters[1]))
print(io, ", uhfunc: ", typeof(uh.uhfunc).parameters[1])
print(io, ")")
else
println(io, "UnitHydroFlux:")
println(io, " Inputs: ", isempty(uh.meta.inputs) ? "nothing" : join(uh.meta.inputs, ", "))
println(io, " Outputs: ", isempty(uh.meta.outputs) ? "nothing" : join(uh.meta.outputs, ", "))
println(io, " Parameters: ", isempty(uh.meta.params) ? "nothing" : join(uh.meta.params, ", "))
println(io, " UnitFunction: ", nameof(typeof(uh.uhfunc).parameters[1]))
println(io, " SolveType: ", nameof(typeof(uh).parameters[end]))
println(io, " UnitFunction: ", typeof(uh.uhfunc).parameters[1])
println(io, " SolveType: ", typeof(uh).parameters[end])
end
end

Expand Down
37 changes: 4 additions & 33 deletions test/run_bucket.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
@test Set(HydroModels.get_output_names(snow_ele)) == Set((:pet, :snowfall, :rainfall, :melt))
@test Set(HydroModels.get_state_names(snow_ele)) == Set((:snowpack,))
end

result = snow_ele(input, pas)
config = (timeidx=ts, solver=ManualSolver{true}())
result = snow_ele(input, pas, config=config)
ele_state_and_output_names = vcat(HydroModels.get_state_names(snow_ele), HydroModels.get_output_names(snow_ele))
result = NamedTuple{Tuple(ele_state_and_output_names)}(eachslice(result, dims=1))

@testset "test first output for hydro element" begin
snowpack0 = init_states[:snowpack]
pet0 = snow_funcs[1]([input_ntp.temp[1], input_ntp.lday[1]], ComponentVector(params=ComponentVector()))[1]
Expand All @@ -46,42 +47,12 @@
@test melt0 == result.melt[1]
end

@testset "test ode solved results" begin
prcp_itp = LinearInterpolation(input_ntp.prcp, ts)
temp_itp = LinearInterpolation(input_ntp.temp, ts)

function snowpack_bucket!(du, u, p, t)
snowpack_ = u[1]
Df, Tmax, Tmin = p.Df, p.Tmax, p.Tmin
prcp_, temp_ = prcp_itp(t), temp_itp(t)
snowfall_ = step_func(Tmin - temp_) * prcp_
melt_ = step_func(temp_ - Tmax) * step_func(snowpack_) * min(snowpack_, Df * (temp_ - Tmax))
du[1] = snowfall_ - melt_
end
prob = ODEProblem(snowpack_bucket!, [init_states.snowpack], (ts[1], ts[end]), params)
sol = solve(prob, Tsit5(), saveat=ts, reltol=1e-3, abstol=1e-3)
num_u = length(prob.u0)
manual_result = [sol[i, :] for i in 1:num_u]
ele_params_idx = [getaxes(pas[:params])[1][nm].idx for nm in HydroModels.get_param_names(snow_ele)]
paramfunc = (p) -> [p[:params][idx] for idx in ele_params_idx]

param_func, nn_param_func = HydroModels._get_parameter_extractors(snow_ele, pas)
itpfunc_list = map((var) -> LinearInterpolation(var, ts, extrapolate=true), eachrow(input))
ode_input_func = (t) -> [itpfunc(t) for itpfunc in itpfunc_list]
du_func = HydroModels._get_du_func(snow_ele, ode_input_func, param_func, nn_param_func)
solver = HydroModels.ODESolver(alg=Tsit5(), reltol=1e-3, abstol=1e-3)
initstates_mat = collect(pas[:initstates][HydroModels.get_state_names(snow_ele)])
#* solve the problem by call the solver
solved_states = solver(du_func, pas, initstates_mat, ts)
@test manual_result[1] == solved_states[1, :]
end

@testset "test all of the output" begin
param_func, nn_param_func = HydroModels._get_parameter_extractors(snow_ele, pas)
itpfunc_list = map((var) -> LinearInterpolation(var, ts, extrapolate=true), eachrow(input))
ode_input_func = (t) -> [itpfunc(t) for itpfunc in itpfunc_list]
du_func = HydroModels._get_du_func(snow_ele, ode_input_func, param_func, nn_param_func)
solver = HydroModels.ODESolver(alg=Tsit5(), reltol=1e-3, abstol=1e-3)
solver = ManualSolver{true}()
initstates_mat = collect(pas[:initstates][HydroModels.get_state_names(snow_ele)])
#* solve the problem by call the solver
snowpack_vec = solver(du_func, pas, initstates_mat, ts)[1, :]
Expand Down
20 changes: 0 additions & 20 deletions test/run_flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,6 @@ end
@test HydroModels.get_state_names(state_flux_3) == [:d,]
end

# todo muskingum need rebuild
# @testset "test muskingum route flux" begin
# @variables q1

# # Building the Muskingum routing flux
# k, x = 3.0, 0.2
# pas = ComponentVector(params=(k=k, x=x,))
# msk_flux = HydroModels.MuskingumRouteFlux(q1)
# input = Float64[1 2 3 2 3 2 5 7 8 3 2 1]
# re = msk_flux(input, pas)

# # Verifying the input, output, and parameter names
# @test HydroModels.get_input_names(msk_flux) == [:q1]
# @test HydroModels.get_output_names(msk_flux) == [:q1_routed]
# @test HydroModels.get_param_names(msk_flux) == [:k, :x]

# # Checking the size and values of the output
# @test size(re) == size(input)
# @test re ≈ [1.0 0.977722 1.30086 1.90343 1.919 2.31884 2.15305 3.07904 4.39488 5.75286 4.83462 3.89097] atol = 1e-1
# end


@testset "test neural flux (single output)" begin
Expand Down
23 changes: 0 additions & 23 deletions test/run_groute.jl

This file was deleted.

Loading

0 comments on commit 35ff1bb

Please sign in to comment.