From 06cac1af5c9c2aee5c7cc85f68e5577fa7921935 Mon Sep 17 00:00:00 2001 From: Patrick Aschermayr Date: Mon, 10 Jul 2023 21:25:53 +0200 Subject: [PATCH] 0.5 --- Project.toml | 4 +- src/Core/Core.jl | 85 +++- src/Core/constrain/constrain.jl | 12 - src/Core/constrain/constraints/bijector.jl | 7 +- src/Core/constrain/constraints/constrained.jl | 2 +- src/Core/constrain/constraints/constraints.jl | 4 - src/Core/constrain/constraints/corrmatrix.jl | 126 ----- src/Core/constrain/constraints/covmatrix.jl | 128 ----- .../constrain/constraints/distribution.jl | 10 +- src/Core/constrain/constraints/fixed.jl | 2 +- src/Core/constrain/constraints/multi.jl | 2 +- src/Core/constrain/constraints/simplex.jl | 125 ----- .../constrain/constraints/unconstrained.jl | 2 +- src/Core/constrain/params.jl | 2 +- src/Core/constrain/transform.jl | 2 +- src/Core/flatten/construct.jl | 37 -- src/Core/flatten/nested/abstractarray.jl | 16 +- src/Core/flatten/nested/namedtuple.jl | 13 +- src/Core/flatten/nested/tuple.jl | 16 +- src/Core/flatten/types/float_cholesky.jl | 52 ++ src/Core/flatten/types/types.jl | 1 + src/Core/parameterinfo.jl | 64 ++- src/ModelWrappers.jl | 12 +- src/Models/Models.jl | 2 - src/Models/_soss.jl | 105 ---- src/Models/initial.jl | 2 +- src/Models/modelwrapper.jl | 30 +- src/Models/objective.jl | 7 +- src/Models/tagged.jl | 33 +- test/TestHelper.jl | 12 +- test/runtests.jl | 4 +- test/test-flatten.jl | 54 +- test/test-flatten/constraints.jl | 481 +++++++++++++++--- test/test-flatten/flatten.jl | 6 +- test/test-flatten/nested.jl | 31 +- test/test-flatten/types.jl | 64 ++- test/test-models.jl | 97 +++- test/test-objective.jl | 7 +- test/test-tagged.jl | 25 +- 39 files changed, 889 insertions(+), 795 deletions(-) delete mode 100644 src/Core/constrain/constraints/corrmatrix.jl delete mode 100644 src/Core/constrain/constraints/covmatrix.jl delete mode 100644 src/Core/constrain/constraints/simplex.jl create mode 100644 src/Core/flatten/types/float_cholesky.jl delete mode 100644 src/Models/_soss.jl diff --git a/Project.toml b/Project.toml index 62c6815..e301d4e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ModelWrappers" uuid = "44c54197-9f56-47cc-9960-7f2e20bfb0d6" authors = ["Patrick Aschermayr "] -version = "0.4.3" +version = "0.5.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -18,7 +18,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] ArgCheck = "2" BaytesCore = "0.2" -Bijectors = "0.12" +Bijectors = "0.13" ChainRulesCore = "1" Distributions = "0.24, 0.25" DocStringExtensions = "0.8, 0.9" diff --git a/src/Core/Core.jl b/src/Core/Core.jl index bc4e14b..c1e4d8b 100644 --- a/src/Core/Core.jl +++ b/src/Core/Core.jl @@ -36,8 +36,73 @@ struct UnflattenStrict <: UnflattenTypes end struct UnflattenFlexible <: UnflattenTypes end ############################################################################################ -include("constrain/constrain.jl") +# Abstract supertypes for flatten/unflatten parameter + +""" + $(FUNCTIONNAME)(x ) +Convert 'x' into a Vector. + +# Examples +```julia +``` +""" +function flatten end + +""" + $(FUNCTIONNAME)(x ) +Convert 'x' into a Vector that is AD compatible. + +# Examples +```julia +``` +""" +function flattenAD end + +""" + $(FUNCTIONNAME)(x ) +Unflatten 'x' into original shape. + +# Examples +```julia +``` +""" +function unflatten end + +""" + $(FUNCTIONNAME)(x ) +Unflatten 'x' into original shape but keep type information of 'x' for AD compatibility. + +# Examples +```julia +``` +""" +function unflattenAD end + +############################################################################################ +# Abstract supertypes for constrain/unconstrain +""" +$(TYPEDEF) +Abstract super type for parameter constraints. +""" +abstract type AbstractConstraint end + +"Constrain `val` with given `constraint`" +function constrain end + +"Unconstrain `val` with given `constraint`" +function unconstrain end + +############################################################################################ +# Default Methods for unconstrain_flatten and unflatten_constrain +function unconstrain_flatten end +function unconstrain_flattenAD end + +function unflatten_constrain end +function unflattenAD_constrain end + +############################################################################################ include("flatten/flatten.jl") +include("constrain/constrain.jl") include("utility.jl") include("checks.jl") @@ -49,6 +114,22 @@ include("parameterinfo.jl") export FlattenTypes, FlattenAll, FlattenContinuous, + UnflattenTypes, UnflattenStrict, - UnflattenFlexible + UnflattenFlexible, + + flatten, + flattenAD, + + unflatten, + unflattenAD, + + AbstractConstraint, + constrain, + unconstrain, + + unconstrain_flatten, + unconstrain_flattenAD, + unflatten_constrain, + unflattenAD_constrain diff --git a/src/Core/constrain/constrain.jl b/src/Core/constrain/constrain.jl index 6844582..361ba62 100644 --- a/src/Core/constrain/constrain.jl +++ b/src/Core/constrain/constrain.jl @@ -1,17 +1,5 @@ ############################################################################################ #!NOTE: These are abstract super types needed if additional constraints are added to ModelWrappers. -""" -$(TYPEDEF) -Abstract super type for parameter constraints. -""" -abstract type AbstractConstraint end - -"Constrain `val` with given `constraint`" -function constrain(constraint::AbstractConstraint, val) end - -"Unconstrain `val` with given `constraint`" -function unconstrain(constraint::AbstractConstraint, val) end - "Compute log(abs(determinant(jacobian(`x`)))) for given transformer to unconstrained (!) domain." function log_abs_det_jac end diff --git a/src/Core/constrain/constraints/bijector.jl b/src/Core/constrain/constraints/bijector.jl index f19258f..3932630 100644 --- a/src/Core/constrain/constraints/bijector.jl +++ b/src/Core/constrain/constraints/bijector.jl @@ -13,7 +13,7 @@ end ############################################################################################ #= -2. Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. +2.1 Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. Dimensions of val and valᵤ should be the same, flattening will be handled separately. =# function unconstrain(bijection::Bijection, val) @@ -52,11 +52,12 @@ end function _check( _rng::Random.AbstractRNG, b::Bijection, - val::Union{R,Array{R},AbstractArray}, + val::Union{Cholesky, R,Array{R},AbstractArray}, ) where {R<:Real} - return typeof( unconstrain(b, val) ) == typeof(val) ? true : false + return true end + ############################################################################################ #Export export diff --git a/src/Core/constrain/constraints/constrained.jl b/src/Core/constrain/constraints/constrained.jl index 4c3940c..4d73d31 100644 --- a/src/Core/constrain/constraints/constrained.jl +++ b/src/Core/constrain/constraints/constrained.jl @@ -20,7 +20,7 @@ end ############################################################################################ #= -2. Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. +2.1 Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. Dimensions of val and valᵤ should be the same, flattening will be handled separately. =# function unconstrain(constrained::Constrained, val) diff --git a/src/Core/constrain/constraints/constraints.jl b/src/Core/constrain/constraints/constraints.jl index f5a5a67..6aca5bf 100644 --- a/src/Core/constrain/constraints/constraints.jl +++ b/src/Core/constrain/constraints/constraints.jl @@ -2,10 +2,6 @@ include("bijector.jl") include("distribution.jl") -include("simplex.jl") -include("corrmatrix.jl") -include("covmatrix.jl") - include("unconstrained.jl") include("constrained.jl") include("fixed.jl") diff --git a/src/Core/constrain/constraints/corrmatrix.jl b/src/Core/constrain/constraints/corrmatrix.jl deleted file mode 100644 index 14b7bd8..0000000 --- a/src/Core/constrain/constraints/corrmatrix.jl +++ /dev/null @@ -1,126 +0,0 @@ -############################################################################################ -# 1. Create a new Constraint, MyConstraint <: AbstractConstraint. -""" -$(TYPEDEF) - -Utility struct to help assign boundaries to parameter. - -# Fields -$(TYPEDFIELDS) -""" -struct CorrelationMatrix{B<:Bijection} <: AbstractConstraint - bijection::B - function CorrelationMatrix() - b = Bijection(Bijectors.CorrBijector()) - return new{typeof(b)}(b) - end -end - -############################################################################################ -#= -2. Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. -Dimensions of val and valᵤ should be the same, flattening will be handled separately. -=# -function unconstrain(constraint::CorrelationMatrix, val) - return unconstrain(constraint.bijection, val) -end -function constrain(constraint::CorrelationMatrix, valᵤ) - return constrain(constraint.bijection, valᵤ) -end - -############################################################################################ -# 3. Optional - Check if check_transformer(constraint, val) works -#= -constraint = CorrelationMatrix() -val = [1. .2 ; .2 1.] -val_u = unconstrain(constraint, val) -val_o = constrain(constraint, val_u) -check_constraint(constraint, val) -=# - -############################################################################################ -# 4. If Objective is used, include a method that computes logabsdet from transformation to unconstrained domain. Same syntax as in Bijectors.jl package is used, i.e., -log_abs_det_jac is returned for computations. -function log_abs_det_jac(constraint::CorrelationMatrix, θ::T) where {T} - return log_abs_det_jac(constraint.bijection, θ) -end - -############################################################################################ -# 5. Add _check function to check for all other peculiar things that should be tested if new releases come out and add to Test section. -function _check( - _rng::Random.AbstractRNG, - constraint::CorrelationMatrix, - val::Matrix{R}, -) where {R<:Real} - ArgCheck.@argcheck all(LinearAlgebra.diag(val) .== 1.0) - return true -end - -############################################################################################ -# 6. Optionally - choose to only flatten upper non-diagonal parameter if Correlationmatrix is constraint -#= !NOTES: - Unconstrained will always be 0 everywhere except upper diagonal elements. All other entries do not matter for constrain/unconstrain. - Constrained will always have unit variance. -=# -function construct_flatten( - output::Type{T}, - flattentype::F, - unflattentype::UnflattenStrict, - constraint::C, - x::Matrix{R}, -) where { - T<:AbstractFloat, - F<:FlattenTypes, - R<:Real, - C<:Union{CorrelationMatrix, DistributionConstraint{<:Distributions.LKJ}, Distributions.LKJ, Bijectors.CorrBijector} -} - #!NOTE: CorrBijector seems to unconstrain to a Upper Diagonal Matrix - idx_upper = tag(x, true, false) - len = length(x) - len_unflat = sum(idx_upper) - #!NOTE: Buffer should be of type R, not T, as we want same type back afterwards - function flatten_CorrMatrix(x::AbstractMatrix{R}) where {R<:Real} - ArgCheck.@argcheck length(x) == len - return Vector{T}(flatten_Symmetric(x, idx_upper)) - end - buffer_unflat = ones(R, size(x)) - function CorrMatrix_from_vec(v::Union{<:Real,AbstractVector{<:Real}}) - ArgCheck.@argcheck length(v) == len_unflat - return Symmetric_from_flatten!(buffer_unflat, v, idx_upper) - end - return flatten_CorrMatrix, CorrMatrix_from_vec -end -function construct_flatten( - output::Type{T}, - flattentype::F, - unflattentype::UnflattenFlexible, - constraint::C, - x::Matrix{R}, -) where { - T<:AbstractFloat, - F<:FlattenTypes, - R<:Real, - C<:Union{CorrelationMatrix, DistributionConstraint{<:Distributions.LKJ},Distributions.LKJ, Bijectors.CorrBijector} -} - #!NOTE: CorrBijector seems to unconstrain to a Upper Diagonal Matrix - idx_upper = tag(x, true, false) - len = length(x) - len_unflat = sum(idx_upper) - function flatten_CorrMatrix_AD(x::AbstractMatrix{R}) where {R<:Real} - ArgCheck.@argcheck length(x) == len - return Vector{R}(flatten_Symmetric(x, idx_upper)) - end - function CorrMatrix_from_vec_AD(v::Union{<:Real,AbstractVector{<:Real}}) - ArgCheck.@argcheck length(v) == len_unflat - return Symmetric_from_flatten!(ones(eltype(v), size(x)), v, idx_upper) - end - return flatten_CorrMatrix_AD, CorrMatrix_from_vec_AD -end - -############################################################################################ -#Export -export - CorrelationMatrix, - construct_flatten, - constrain, - unconstrain, - log_abs_det_jac diff --git a/src/Core/constrain/constraints/covmatrix.jl b/src/Core/constrain/constraints/covmatrix.jl deleted file mode 100644 index 7d33ea1..0000000 --- a/src/Core/constrain/constraints/covmatrix.jl +++ /dev/null @@ -1,128 +0,0 @@ -############################################################################################ -# 1. Create a new Constraint, MyConstraint <: AbstractConstraint. -""" -$(TYPEDEF) - -Utility struct to help assign boundaries to parameter. - -# Fields -$(TYPEDFIELDS) -""" -struct CovarianceMatrix{B<:Bijection} <: AbstractConstraint - bijection::B - function CovarianceMatrix() - b = Bijection(Bijectors.PDBijector()) - return new{typeof(b)}(b) - end -end - -############################################################################################ -#= -2. Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. -Dimensions of val and valᵤ should be the same, flattening will be handled separately. -=# -function unconstrain(constraint::CovarianceMatrix, val) - return unconstrain(constraint.bijection, val) -end -function constrain(constraint::CovarianceMatrix, valᵤ) - return constrain(constraint.bijection, valᵤ) -end - -############################################################################################ -# 3. Optional - Check if check_transformer(constraint, val) works -#= -b = Bijectors.PDBijector() -constraint = CovarianceMatrix() -val = [4. .8 ; .8 3.] -val_u = Bijectors.transform(b, val) -val_o = Bijectors.transform(inverse(b), val_u) -val_u = unconstrain(constraint, val) -val_o = constrain(constraint, val_u) -check_constraint(constraint, val) -=# - -############################################################################################ -# 4. If Objective is used, include a method that computes logabsdet from transformation to unconstrained domain. Same syntax as in Bijectors.jl package is used, i.e., -log_abs_det_jac is returned for computations. -function log_abs_det_jac(constraint::CovarianceMatrix, θ::T) where {T} - return log_abs_det_jac(constraint.bijection, θ) -end - -############################################################################################ -# 5. Add _check function to check for all other peculiar things that should be tested if new releases come out and add to Test section. -function _check( - _rng::Random.AbstractRNG, - constraint::CovarianceMatrix, - val::Matrix{R}, -) where {R<:Real} - ArgCheck.@argcheck LinearAlgebra.issymmetric(val) - return true -end - -############################################################################################ -# 6. Optionally - choose to only flatten upper non-diagonal parameter if Correlationmatrix is constraint -#!TODO: Works with flatten/unflatten - but constraint/unconstraint seems to deduce wrong type for ReverseDiff from Bijector - works fine with ForwardDiff/Zygote -#!NOTE: Bijectors map to lower triangular matrix while most AD libraries evaluate upper triangular matrices. -function construct_flatten( - output::Type{T}, - flattentype::F, - unflattentype::UnflattenStrict, - constraint::C, - x::Matrix{R}, -) where { - T<:AbstractFloat, - F<:FlattenTypes, - R<:Real, - C<:Union{CovarianceMatrix, DistributionConstraint{<:Distributions.InverseWishart}, Distributions.InverseWishart, Bijectors.PDBijector} -} - #!NOTE: PDBijector seems to unconstrain to a Lower Diagonal Matrix - idx_upper = tag(x, false, true) - len = length(x) - len_unflat = sum(idx_upper) - function flatten_CovMatrix(x::AbstractMatrix{R}) where {R<:Real} - ArgCheck.@argcheck length(x) == len - return Vector{T}(flatten_Symmetric(x, idx_upper)) - end - buffer_unflat = zeros(R, size(x)) - function CovMatrix_from_vec(v::Union{<:Real,AbstractVector{<:Real}}) - ArgCheck.@argcheck length(v) == len_unflat - return Symmetric_from_flatten!(buffer_unflat, v, idx_upper) - end - return flatten_CovMatrix, CovMatrix_from_vec -end -function construct_flatten( - output::Type{T}, - flattentype::F, - unflattentype::UnflattenFlexible, - constraint::C, - x::Matrix{R}, -) where { - T<:AbstractFloat, - F<:FlattenTypes, - R<:Real, - C<:Union{CovarianceMatrix, DistributionConstraint{<:Distributions.InverseWishart}, Distributions.InverseWishart, Bijectors.PDBijector} -} - #!NOTE: PDBijector seems to unconstrain to a Lower Diagonal Matrix - idx_upper = tag(x, false, true) - len = length(x) - len_unflat = sum(idx_upper) - function flatten_CovMatrix_AD(x::AbstractMatrix{R}) where {R<:Real} - ArgCheck.@argcheck length(x) == len - return Vector{R}(flatten_Symmetric(x, idx_upper)) - end - dims = size(x) - function CovMatrix_from_vecAD(v::Union{<:Real,AbstractVector{<:Real}}) - ArgCheck.@argcheck length(v) == len_unflat - buffer_unflat = zeros(eltype(v), dims) - return Symmetric_from_flatten!(buffer_unflat, v, idx_upper) - end - return flatten_CovMatrix_AD, CovMatrix_from_vecAD -end - -############################################################################################ -#Export -export - CovarianceMatrix, - construct_flatten, - constrain, - unconstrain, - log_abs_det_jac diff --git a/src/Core/constrain/constraints/distribution.jl b/src/Core/constrain/constraints/distribution.jl index 73ae7f8..6f2943b 100644 --- a/src/Core/constrain/constraints/distribution.jl +++ b/src/Core/constrain/constraints/distribution.jl @@ -19,7 +19,7 @@ Param(_rng::Random.AbstractRNG, constraint::A, val::B) where {A<:Distributions.D ############################################################################################ #= -2. Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. +2.1 Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. Dimensions of val and valᵤ should be the same, flattening will be handled separately. =# function unconstrain(dist::DistributionConstraint, val) @@ -50,20 +50,20 @@ end function _check( _rng::Random.AbstractRNG, d::DistributionConstraint, - val::Union{R,Array{R},AbstractArray}, + val::Union{Factorization, R, Array{R}, AbstractArray}, ) where {R<:Real} _val = rand(_rng, d.dist) return _check(_rng, d.bijection, val) && typeof(val) == typeof(_val) && size(val) == size(_val) ? true : false end ############################################################################################ -# 6. Optionally - Ignore non-specified distributions when flattening +# 6.1 Optionally - Ignore non-specified distributions when flattening function construct_flatten( output::Type{T}, flattentype::F, unflattentype::U, constraint::Distributions.Distribution, - x::Union{R,Array{R}}, + x::Union{Factorization, R, Array{R}}, ) where { T<:AbstractFloat, F<:FlattenTypes, @@ -73,6 +73,8 @@ function construct_flatten( return construct_flatten(T, flattentype, unflattentype, x) end +# 6.2 Implement custom flattening behavior for Bijectors + ############################################################################################ # Additional functionality -- that is not considered for other AbstractConstraints -- to evaluate prior logdensities etc that make logposterior definitions easier. function sample_constraint(_rng::Random.AbstractRNG, constraint::DistributionConstraint, val) diff --git a/src/Core/constrain/constraints/fixed.jl b/src/Core/constrain/constraints/fixed.jl index 11e8544..7e365ac 100644 --- a/src/Core/constrain/constraints/fixed.jl +++ b/src/Core/constrain/constraints/fixed.jl @@ -18,7 +18,7 @@ end ############################################################################################ #= -2. Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. +2.1 Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. Dimensions of val and valᵤ should be the same, flattening will be handled separately. =# function unconstrain(fixed::Fixed, val) diff --git a/src/Core/constrain/constraints/multi.jl b/src/Core/constrain/constraints/multi.jl index 6e40c7e..5c626fa 100644 --- a/src/Core/constrain/constraints/multi.jl +++ b/src/Core/constrain/constraints/multi.jl @@ -21,7 +21,7 @@ Param(_rng::Random.AbstractRNG, constraint::D, val::B) where {D<:Vector{<:Array{ ############################################################################################ #= -2. Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. +2.1 Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. Dimensions of val and valᵤ should be the same, flattening will be handled separately. =# function unconstrain(multi::MultiConstraint, val) diff --git a/src/Core/constrain/constraints/simplex.jl b/src/Core/constrain/constraints/simplex.jl deleted file mode 100644 index f9ee41b..0000000 --- a/src/Core/constrain/constraints/simplex.jl +++ /dev/null @@ -1,125 +0,0 @@ -############################################################################################ -# 1. Create a new Constraint, MyConstraint <: AbstractConstraint. -""" -$(TYPEDEF) - -Utility struct to help assign boundaries to parameter. - -# Fields -$(TYPEDFIELDS) -""" -struct Simplex{B<:Bijection} <: AbstractConstraint - len::Int64 - bijection::B - function Simplex(len::Integer) - ArgCheck.@argcheck len > 0 - b = Bijection(Bijectors.SimplexBijector{true}()) - return new{typeof(b)}(len, b) - end -end -Simplex(vec::AbstractVector) = Simplex(length(vec)) - -############################################################################################ -#= -2. Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. -Dimensions of val and valᵤ should be the same, flattening will be handled separately. -=# -function unconstrain(constraint::Simplex, val) - return unconstrain(constraint.bijection, val) -end -function constrain(constraint::Simplex, valᵤ) - return constrain(constraint.bijection, valᵤ) -end - -############################################################################################ -# 3. Optional - Check if check_transformer(constraint, val) works -#= -constraint = Simplex(3) -val = [.2, .3, .5] -val_u = unconstrain(constraint, val) -val_o = constrain(constraint, val_u) -check_constraint(constraint, val) -=# - -############################################################################################ -# 4. If Objective is used, include a method that computes logabsdet from transformation to unconstrained domain. Same syntax as in Bijectors.jl package is used, i.e., -log_abs_det_jac is returned for computations. -function log_abs_det_jac(constraint::Simplex, θ::T) where {T} - return log_abs_det_jac(constraint.bijection, θ) -end - -############################################################################################ -# 5. Add _check function to check for all other peculiar things that should be tested if new releases come out and add to Test section. -function _check( - _rng::Random.AbstractRNG, - constraint::Simplex, - val::Vector{R}, -) where {R<:Real} - ArgCheck.@argcheck length(val) == constraint.len - ArgCheck.@argcheck all(val[iter] > 0.0 for iter in eachindex(val)) - return true -end - -############################################################################################ -# 6. Optionally - choose to only flatten k-1 parameter if Simplex is constraint -function construct_flatten( - output::Type{T}, - flattentype::F, - unflattentype::UnflattenStrict, - constraint::C, - x::Vector{R}, -) where { - T<:AbstractFloat, - F<:FlattenTypes, - R<:Real, - C<:Union{Simplex, DistributionConstraint{<:Distributions.Dirichlet}, Distributions.Dirichlet, Bijectors.SimplexBijector} -} - buffer_flat = zeros(T, length(x)-1) - len_flat = length(x) - len_unflat = len_flat-1 - function flatten_Simplex(x_vec::AbstractVector{R}) where {R<:Real} - ArgCheck.@argcheck length(x_vec) == len_flat - return fill_array!(buffer_flat, view(x_vec, 1:len_flat-1)) - end - buffer_unflat = zeros(R, length(x)) - function unflatten_Simplex(v::Union{R,AbstractVector{R}}) where {R<:Real} - ArgCheck.@argcheck length(v) == len_unflat - return Simplex_from_flatten!(buffer_unflat, v) - end - return flatten_Simplex, unflatten_Simplex -end - -function construct_flatten( - output::Type{T}, - flattentype::F, - unflattentype::UnflattenFlexible, - constraint::C, - x::Vector{R}, -) where { - T<:AbstractFloat, - F<:FlattenTypes, - R<:Real, - C<:Union{Simplex, DistributionConstraint{<:Distributions.Dirichlet}, Distributions.Dirichlet, Bijectors.SimplexBijector} -} - len_flat = length(x) - len_unflat = len_flat-1 - function flatten_Simplex_AD(x_vec::AbstractVector{R}) where {R<:Real} - ArgCheck.@argcheck length(x_vec) == len_flat - buffer = zeros(R, len_flat-1) - return fill_array!(buffer, view(x_vec, 1:len_flat-1)) - end - function unflatten_Simplex_AD(v::Union{R,AbstractVector{R}}) where {R<:Real} - ArgCheck.@argcheck length(v) == len_unflat - buffer = zeros(eltype(v), length(x)) - return Simplex_from_flatten!(buffer, v) - end - return flatten_Simplex_AD, unflatten_Simplex_AD -end - -############################################################################################ -#Export -export - Simplex, - construct_flatten, - constrain, - unconstrain, - log_abs_det_jac diff --git a/src/Core/constrain/constraints/unconstrained.jl b/src/Core/constrain/constraints/unconstrained.jl index d48668b..da88b51 100644 --- a/src/Core/constrain/constraints/unconstrained.jl +++ b/src/Core/constrain/constraints/unconstrained.jl @@ -18,7 +18,7 @@ end ############################################################################################ #= -2. Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. +2.1 Define functions to unconstrain(constraint, val) to unconstrained domain valᵤ, and a function constrain(constraint, valᵤ) back to val. Dimensions of val and valᵤ should be the same, flattening will be handled separately. =# function unconstrain(unconstrained::Unconstrained, val) diff --git a/src/Core/constrain/params.jl b/src/Core/constrain/params.jl index d69aba8..b44acfd 100644 --- a/src/Core/constrain/params.jl +++ b/src/Core/constrain/params.jl @@ -12,7 +12,7 @@ function check_constraint(constraint::AbstractConstraint, val::V) where {V} valᵤ = unconstrain(constraint, val) #Constrain to original domain and compare vals valₒ = constrain(constraint, valᵤ) - _check = val ≈ valₒ + _check = Base.isapprox(val, valₒ) return _check end diff --git a/src/Core/constrain/transform.jl b/src/Core/constrain/transform.jl index 16598c6..a7c8894 100644 --- a/src/Core/constrain/transform.jl +++ b/src/Core/constrain/transform.jl @@ -34,4 +34,4 @@ end export TransformConstructor, constrain, unconstrain, - log_abs_det_jac + log_abs_det_jac \ No newline at end of file diff --git a/src/Core/flatten/construct.jl b/src/Core/flatten/construct.jl index f54064a..a63cdb1 100644 --- a/src/Core/flatten/construct.jl +++ b/src/Core/flatten/construct.jl @@ -114,15 +114,6 @@ function ReConstructor(constraint, x) end ############################################################################################ -""" - $(FUNCTIONNAME)(x ) -Convert 'x' into a Vector. - -# Examples -```julia -``` -""" -function flatten end function flatten(constructor::ReConstructor, x) return constructor.flatten.strict(x) end @@ -131,15 +122,6 @@ function flatten(x) return flatten(constructor, x), constructor end -""" - $(FUNCTIONNAME)(x ) -Convert 'x' into a Vector that is AD compatible. - -# Examples -```julia -``` -""" -function flattenAD end function flattenAD(constructor::ReConstructor, x) return constructor.flatten.flexible(x) end @@ -148,28 +130,9 @@ function flattenAD(x) return flattenAD(constructor, x), constructor end -""" - $(FUNCTIONNAME)(x ) -Unflatten 'x' into original shape. - -# Examples -```julia -``` -""" -function unflatten end function unflatten(constructor::ReConstructor, x) return constructor.unflatten.strict(x) end - -""" - $(FUNCTIONNAME)(x ) -Unflatten 'x' into original shape but keep type information of 'x' for AD compatibility. - -# Examples -```julia -``` -""" -function unflattenAD end function unflattenAD(constructor::ReConstructor, x) return constructor.unflatten.flexible(x) end diff --git a/src/Core/flatten/nested/abstractarray.jl b/src/Core/flatten/nested/abstractarray.jl index bc9668b..271b673 100644 --- a/src/Core/flatten/nested/abstractarray.jl +++ b/src/Core/flatten/nested/abstractarray.jl @@ -77,21 +77,7 @@ function _check( end ) end -#= -############################################################################################ -function construct_transform( - constraint::AbstractArray, - x::AbstractArray -) -## Obtain transform/inversetransform constructor for each element - x_transforms = map(x, constraint) do xᵢ, constraintᵢ - construct_transform(constraintᵢ, xᵢ) - end - _transform, _inversetransform = first.(x_transforms), last.(x_transforms) -## Return flatten/unflatten for AbstractArray - return _transform, _inversetransform -end -=# + ############################################################################################ function constrain( constraint::AbstractArray, diff --git a/src/Core/flatten/nested/namedtuple.jl b/src/Core/flatten/nested/namedtuple.jl index 373ff71..a0830fd 100644 --- a/src/Core/flatten/nested/namedtuple.jl +++ b/src/Core/flatten/nested/namedtuple.jl @@ -84,18 +84,7 @@ function _check( ) where {names1, names2} return _check(_rng, values(constraint), values(x)) end -#= -############################################################################################ -function construct_transform( - constraint::NamedTuple{names1}, - x::NamedTuple{names2} -) where {names1, names2} -## Obtain transform/inversetransform constructor for each element - _transform, _inversetransform = construct_transform(values(constraint), values(x)) -## Return flatten/unflatten for AbstractArray - return NamedTuple{names1}(_transform), NamedTuple{names1}(_inversetransform) -end -=# + ############################################################################################ function constrain( constraint::NamedTuple{names1}, diff --git a/src/Core/flatten/nested/tuple.jl b/src/Core/flatten/nested/tuple.jl index bcc04a3..e818b7c 100644 --- a/src/Core/flatten/nested/tuple.jl +++ b/src/Core/flatten/nested/tuple.jl @@ -83,21 +83,7 @@ function _check( end ) end -#= -############################################################################################ -function construct_transform( - constraint::Tuple, - x::Tuple -) -## Obtain transform/inversetransform constructor for each element - x_transforms = map(x, constraint) do xᵢ, constraintᵢ - construct_transform(constraintᵢ, xᵢ) - end - _transform, _inversetransform = first.(x_transforms), last.(x_transforms) -## Return flatten/unflatten for AbstractArray - return _transform, _inversetransform -end -=# + ############################################################################################ function constrain( constraint::Tuple, diff --git a/src/Core/flatten/types/float_cholesky.jl b/src/Core/flatten/types/float_cholesky.jl new file mode 100644 index 0000000..a11ab8d --- /dev/null +++ b/src/Core/flatten/types/float_cholesky.jl @@ -0,0 +1,52 @@ +function isapprox(x::Cholesky, y::Cholesky) + return isapprox(x.factors, y.factors) +end + +############################################################################################ +function construct_flatten( + output::Type{T}, + flattentype::F, + unflattentype::UnflattenStrict, + x::Cholesky +) where {T<:AbstractFloat,F<:FlattenTypes} + len_lower = binomial(size(x, 1), 2) + len = length(x.factors) + sz = size(x) + buffer_flat = zeros(T, len) + function flatten_to_Cholesky(v::Cholesky) + ArgCheck.@argcheck binomial(size(v, 1), 2) == len_lower + return fill_array!(buffer_flat, v.factors) + end + R = eltype(x.factors) + buffer_unflat = zeros(R, sz) + function unflatten_to_Cholesky(v::Union{R,AbstractVector{R}}) where {R<:Real} + ArgCheck.@argcheck length(v) == len + return Cholesky(fill_array!(buffer_unflat, v), 'L', 0) + end + return flatten_to_Cholesky, unflatten_to_Cholesky +end + +function construct_flatten( + output::Type{T}, + flattentype::F, + unflattentype::UnflattenFlexible, + x::Cholesky +) where {T<:AbstractFloat,F<:FlattenTypes} + len_lower = binomial(size(x, 1), 2) + len = length(x.factors) + sz = size(x) + function flatten_to_Cholesky(v::Cholesky) + ArgCheck.@argcheck binomial(size(v, 1), 2) == len_lower + return fill_array!(zeros(eltype(v.factors), len), v.factors) + end + function unflatten_to_Cholesky(v::Union{R,AbstractVector{R}}) where {R<:Real} + ArgCheck.@argcheck length(v) == len + return Cholesky(fill_array!(zeros(eltype(v), sz), v), 'L', 0) + end + return flatten_to_Cholesky, unflatten_to_Cholesky +end + +############################################################################################ +#Export +export + construct_flatten diff --git a/src/Core/flatten/types/types.jl b/src/Core/flatten/types/types.jl index 07524b0..8b5ecf9 100644 --- a/src/Core/flatten/types/types.jl +++ b/src/Core/flatten/types/types.jl @@ -2,6 +2,7 @@ include("float.jl") include("float_vector.jl") include("float_array.jl") +include("float_cholesky.jl") include("integer.jl") ############################################################################################ diff --git a/src/Core/parameterinfo.jl b/src/Core/parameterinfo.jl index d0d4e9e..869ffcc 100644 --- a/src/Core/parameterinfo.jl +++ b/src/Core/parameterinfo.jl @@ -6,19 +6,29 @@ Contains information about parameter distributions, transformations and constrai # Fields $(TYPEDFIELDS) """ -struct ParameterInfo{R<:ReConstructor,T<:TransformConstructor} +struct ParameterInfo{R<:ReConstructor, U<:ReConstructor, T<:TransformConstructor} "Contains information for flatten/unflatten parameter" reconstruct::R + "Contains information to reconstruct unconstrained parameter - important for non-bijective transformations" + reconstructᵤ::U "Contains information for constraining and unconstraining parameter." transform::T function ParameterInfo( - reconstruct::R, transform::T - ) where {R<:ReConstructor,T<:TransformConstructor} - return new{R, T}( - reconstruct, transform + reconstruct::R, reconstructᵤ::U, transform::T + ) where {R<:ReConstructor, U<:ReConstructor, T<:TransformConstructor} + return new{R, U, T}( + reconstruct, reconstructᵤ, transform ) end end +function ParameterInfo(flattendefault::D, constructor::R, transformer::T, val::V) where {D<:FlattenDefault, R<:ReConstructor, T<:TransformConstructor, V} + ## Construct flatten constructor for unconstrained parameterization - important for non-bijective transformations + constructorᵤ = ReConstructor(flattendefault, transformer.constraint, unconstrain(transformer, val)) + return ParameterInfo( + constructor, constructorᵤ, transformer + ) +end + function ParameterInfo( flattendefault::D, constraint::C, val::B ) where {D<:FlattenDefault, C<:NamedTuple, B<:NamedTuple} @@ -26,9 +36,11 @@ function ParameterInfo( constructor = ReConstructor(flattendefault, constraint, val) ## Assign transformer constraint NamedTuple transformer = TransformConstructor(constraint, val) + ## Construct flatten constructor for unconstrained parameterization - important for non-bijective transformations + constructorᵤ = ReConstructor(flattendefault, constraint, unconstrain(transformer, val)) ## Return ParameterInfo return ParameterInfo( - constructor, transformer + constructor, constructorᵤ, transformer ) end function ParameterInfo( @@ -39,19 +51,28 @@ function ParameterInfo( ## Split between values and constraints val = _get_val(parameter) constraint = _get_constraint(parameter) - ## Create flatten constructor - constructor = ReConstructor(flattendefault, constraint, val) - ## Assign transformer constraint NamedTuple - transformer = TransformConstructor(constraint, val) - ## Return ParameterInfo return ParameterInfo( - constructor, transformer + flattendefault, constraint, val ) end ############################################################################################ length(info::ParameterInfo) = info.reconstruct.unflatten.strict._unflatten.sz[end] +############################################################################################ +function flatten(info::ParameterInfo, x) + return flatten(info.reconstruct, x) +end +function flattenAD(info::ParameterInfo, x) + return flattenAD(info.reconstruct, x) +end +function unflatten(info::ParameterInfo, x) + return unflatten(info.reconstruct, x) +end +function unflattenAD(info::ParameterInfo, x) + return unflattenAD(info.reconstruct, x) +end + ############################################################################################ function constrain(info::ParameterInfo, valᵤ::V) where {V} return constrain(info.transform, valᵤ) @@ -63,20 +84,21 @@ function log_abs_det_jac(info::ParameterInfo, val::V) where {V} return log_abs_det_jac(info.transform, val) end -############################################################################################ -function flatten(info::ParameterInfo, x) - return flatten(info.reconstruct, x) +function unconstrain_flatten(info::ParameterInfo, val::V) where {V} + return flatten(info.reconstructᵤ, unconstrain(info.transform, val)) end -function flattenAD(info::ParameterInfo, x) - return flattenAD(info.reconstruct, x) +function unconstrain_flattenAD(info::ParameterInfo, val::V) where {V} + return flattenAD(info.reconstructᵤ, unconstrain(info.transform, val)) end -function unflatten(info::ParameterInfo, x) - return unflatten(info.reconstruct, x) + +function unflatten_constrain(info::ParameterInfo, valᵤ::V) where {V} + return constrain(info.transform, unflatten(info.reconstructᵤ, valᵤ)) end -function unflattenAD(info::ParameterInfo, x) - return unflattenAD(info.reconstruct, x) +function unflattenAD_constrain(info::ParameterInfo, valᵤ::V) where {V} + return constrain(info.transform, unflattenAD(info.reconstructᵤ, valᵤ)) end ############################################################################################ #export export ParameterInfo + diff --git a/src/ModelWrappers.jl b/src/ModelWrappers.jl index 66e057f..0e24e33 100644 --- a/src/ModelWrappers.jl +++ b/src/ModelWrappers.jl @@ -2,7 +2,7 @@ module ModelWrappers ############################################################################################ #Import External packages -import Base: Base, length, fill, fill!, print +import Base: Base, length, vec, fill, fill!, print, isapprox import StatsBase: StatsBase, sample, sample! import BaytesCore: BaytesCore, subset, update, generate_showvalues, generate using BaytesCore: @@ -23,7 +23,7 @@ using Random: Random, AbstractRNG, GLOBAL_RNG #!NOTE: These libraries are relevant for transform part using ChainRulesCore -using LinearAlgebra: LinearAlgebra, Diagonal, LowerTriangular, tril!, diag, issymmetric +using LinearAlgebra: LinearAlgebra, Factorization, Cholesky, Diagonal, LowerTriangular, tril!, diag, issymmetric using Distributions: Distributions, Distribution, logpdf using Bijectors: Bijectors, Bijector, logpdf_with_trans, transform, @@ -36,6 +36,10 @@ const max_val = 1e+100 "Smallest decrease allowed in the log objective results before tagged as divergent." const min_Δ = -1e+3 +function length_constrained end +function length_unconstrained end + + ############################################################################################ #Import include("Core/Core.jl") @@ -46,6 +50,8 @@ include("Models/Models.jl") export UpdateBool, UpdateTrue, - UpdateFalse + UpdateFalse, + length_constrained, + length_unconstrained end diff --git a/src/Models/Models.jl b/src/Models/Models.jl index bcabd4b..dbbed1b 100644 --- a/src/Models/Models.jl +++ b/src/Models/Models.jl @@ -9,7 +9,5 @@ include("objective.jl") include("initial.jl") include("predictive.jl") -#!NOTE: Remove Soss dependency from ModelWrappers because of heavy deps. Can make separate BaytesSoss later on. -#include("_soss.jl") ############################################################################################ # Export diff --git a/src/Models/_soss.jl b/src/Models/_soss.jl deleted file mode 100644 index 8e5bfc0..0000000 --- a/src/Models/_soss.jl +++ /dev/null @@ -1,105 +0,0 @@ -#= -############################################################################################ -using Soss: Soss, ConditionalModel -import Soss: Soss, predict, simulate - -############################################################################################ -"Wrapper functions to work with Soss Models" -#= -Note: Some improvements to work on: - -> In Objective (if multiple sampler are used to target only subset of parameter): - When Objective is formed, make posterior such that only tagged parameter are evaluated - -> would need to change old model.id in that case - -> Make Methods for: - simulate - generate - Predict -> predict(m(), (namedtuplevals)) - -> For Sequential Estimation methods (SMC and beyond) - - Need to update data in SOSS model separately (when created as input?) - - Need a way to separate log-prior and log-likelihood -=# - -############################################################################################ -# ModelWrapper part -""" -$(SIGNATURES) -Best guess for Soss Model parameter, excluding data and hyperparameter. Not exported. - -# Examples -```julia -``` - -""" -function _guess_soss_param(soss_posterior::M) where {M<:Soss.ConditionalModel} - ## All parameter from posterior.model.dists, except data (Hyperparameter should be fixed and have no model.dists entry) - param = setdiff(keys(soss_posterior.model.dists), keys(soss_posterior.obs)) - ## Check if all param Symbols have values assigned - ArgCheck.@argcheck all(haskey(soss_posterior.argvals, sym) for sym in param) "Not all posterior Soss model parameter have initial value assigned, please create initial value for all of them" - ## Return NamedTuple with initial parameter - return subset(soss_posterior.argvals, tuple(param...)) -end - -############################################################################################ -function ModelWrapper( - soss_posterior::M, flattendefault::F=FlattenDefault() -) where {M<:Soss.ConditionalModel,F<:FlattenDefault} - ## Check if Posterior Soss Model provided - ArgCheck.@argcheck !isempty(soss_posterior.obs) "No posterior Soss model provided, please use: posterior = MyModel | (MyDataName = MyDataValues,)" - ## Guess all parameter - val = _guess_soss_param(soss_posterior) - ## Create prior struct from val - _prior = subset(soss_posterior.model.dists, val) - _prior_eval = NamedTuple{keys(_prior)}(eval(_prior[sym]) for sym in keys(_prior)) - ## Create Param NamedTuple as Safetye check that all defined Soss parameter are consistent with Param syntax - #!NOTE: This is not ideal because all information was already there, but temporarily converting (val, prior) to Param struct guarantees that ModelWrapper can handle user input. - params = NamedTuple{keys(val)}( - Param(val[iter], _prior_eval[iter]) for iter in eachindex(val) - ) - ## Return ModelWrapper - return ModelWrapper(soss_posterior, params, flattendefault) -end - -#= -#!NOTE: Could use predict(posterior, vals) directly as it returns data dimension. However, simualate needed for any algorithms and definitons slightly different in packages, so we leave it blank for now. -function simulate(model::ModelWrapper{M}) where {M<:Soss.ConditionalModel} -end -=# - -############################################################################################ -# Objective part - -#!TODO: Make posterior such that only tagged parameter are evaluated -function Objective(model::ModelWrapper{M}) where {M<:Soss.ConditionalModel} - return Objective(model::ModelWrapper{M}, nothing, Tagged(model)) -end -function Objective( - model::ModelWrapper{M}, tagged::T -) where {M<:Soss.ConditionalModel,T<:Tagged} - return Objective(model::ModelWrapper{M}, nothing, tagged) -end - -function (objective::Objective{<:ModelWrapper{M}})( - θ::NamedTuple -) where {M<:Soss.ConditionalModel} - return Soss.logdensity(objective.model.id(θ)) -end - -#!TODO: Need a way to update logposterior (model.id) with new data -#= - -=# - -#= -#NOTE: Waiting for: https://github.com/cscherrer/Soss.jl/issues/301 -function predict(_rng::Random.AbstractRNG, objective::Objective{<:ModelWrapper{M}}) where {M<:Soss.ConditionalModel} - return nothing -end -function generate(_rng::Random.AbstractRNG, objective::Objective{<:ModelWrapper{M}}) where {M<:Soss.ConditionalModel} - return nothing -end -=# - -############################################################################################ -# Export -export ModelWrapper -=# diff --git a/src/Models/initial.jl b/src/Models/initial.jl index 7800837..94f16b0 100644 --- a/src/Models/initial.jl +++ b/src/Models/initial.jl @@ -43,7 +43,7 @@ function sample(_rng::Random.AbstractRNG, initialization::PriorInitialization, k while !isfinite(ℓθᵤ) && counter <= Ntrials counter += 1 θ = sample(_rng, objective.model, objective.tagged) - ℓθᵤ = objective(flatten(objective.model.info.reconstruct, unconstrain(objective.model.info.transform, θ))) + ℓθᵤ = objective( unconstrain_flatten(objective.model.info, θ) ) end ArgCheck.@argcheck counter <= Ntrials "Could not find initial parameter with finite log target density. Adjust intial values, prior, or increase number of intial samples." return θ diff --git a/src/Models/modelwrapper.jl b/src/Models/modelwrapper.jl index 2b95c74..8372fa0 100644 --- a/src/Models/modelwrapper.jl +++ b/src/Models/modelwrapper.jl @@ -55,7 +55,9 @@ ModelWrapper(parameter::A, arg::C=(;), flattendefault::F=FlattenDefault()) where ############################################################################################ # Basic functions for Model struct -length(model::ModelWrapper) = model.info.reconstruct.unflatten.strict._unflatten.sz[end] +length_constrained(model::ModelWrapper) = model.info.reconstruct.unflatten.strict._unflatten.sz[end] +length_unconstrained(model::ModelWrapper) = model.info.reconstructᵤ.unflatten.strict._unflatten.sz[end] + paramnames(model::ModelWrapper) = keys(model.val) ############################################################################################ @@ -134,6 +136,19 @@ function unconstrain(model::ModelWrapper) return unconstrain(model.info, model.val) end +""" +$(SIGNATURES) +Constrain 'θᵤ' values with model.info ParameterInfo. + +# Examples +```julia +``` + +""" +function constrain(model::ModelWrapper, θ::NamedTuple) + return constrain(model.info, θ) +end + """ $(SIGNATURES) Flatten 'model' values and return as vector. @@ -157,7 +172,10 @@ Flatten and unconstrain 'model' values and return as vector. """ function unconstrain_flatten(model::ModelWrapper) - return flatten(model.info, unconstrain(model)) + return unconstrain_flatten(model.info, model.val) +end +function unconstrain_flattenAD(model::ModelWrapper) + return unconstrain_flattenAD(model.info, model.val) end ######################################### @@ -198,7 +216,10 @@ Constrain and Unflatten vector 'θᵤ' given 'model' constraints. """ function unflatten_constrain(model::ModelWrapper, θᵤ::AbstractVector{T}) where {T<:Real} - return constrain(model.info, unflatten(model, θᵤ)) + return unflatten_constrain(model.info, θᵤ) +end +function unflattenAD_constrain(model::ModelWrapper, θᵤ::AbstractVector{T}) where {T<:Real} + return unflattenAD_constrain(model.info, θᵤ) end """ @@ -317,7 +338,8 @@ export BaseModel, ModelWrapper, simulate, - length, + length_constrained, + length_unconstrained, paramnames, fill, fill!, diff --git a/src/Models/objective.jl b/src/Models/objective.jl index 847dbb6..0ec025a 100644 --- a/src/Models/objective.jl +++ b/src/Models/objective.jl @@ -40,7 +40,10 @@ end ############################################################################################ # Basic functions for Model struct -length(objective::Objective) = length(objective.tagged) +length_constrained(objective::Objective) = length_constrained(objective.tagged) +length_unconstrained(objective::Objective) = length_unconstrained(objective.tagged) + + paramnames(objective::Objective) = paramnames(objective.tagged) ############################################################################################ @@ -123,7 +126,7 @@ function (objective::Objective)(θᵤ::AbstractVector{T}, arg::A = objective.mod @unpack model, tagged, temperature = objective ## Convert vector θᵤ back to constrained space as NamedTuple #!NOTE: This allocates new NamedTuple only once - using a constrain!(buffer, ...) does not improve performance wrt allocations - θ = constrain(tagged.info, unflattenAD(tagged.info, θᵤ)) + θ = unflattenAD_constrain(tagged.info, θᵤ) #!NOTE: There are border cases where θᵤ is still finite, but θ no longer after transformation, so have to cover this separately _checkfinite(θ) || return -Inf ## logabsdet_jac for transformations diff --git a/src/Models/tagged.jl b/src/Models/tagged.jl index fa003e3..1983981 100644 --- a/src/Models/tagged.jl +++ b/src/Models/tagged.jl @@ -36,7 +36,9 @@ end ############################################################################################ # Basic functions for Tagged struct -length(tagged::Tagged) = tagged.info.reconstruct.unflatten.strict._unflatten.sz[end] +length_constrained(tagged::Tagged) = tagged.info.reconstruct.unflatten.strict._unflatten.sz[end] +length_unconstrained(tagged::Tagged) = tagged.info.reconstructᵤ.unflatten.strict._unflatten.sz[end] + paramnames(tagged::Tagged) = keys(tagged.parameter) #A convenient method for evaluating a prior distribution of a NamedTuple parameter @@ -66,11 +68,29 @@ end function unconstrain(model::ModelWrapper, tagged::Tagged) return unconstrain(tagged.info, subset(model, tagged)) end + +""" +$(SIGNATURES) +Constrain 'θᵤ' values with tagged.info ParameterInfo. + +# Examples +```julia +``` + +""" +function constrain(model::ModelWrapper, tagged::Tagged, θ::NamedTuple) + return constrain(tagged.info, θ) +end + function flatten(model::ModelWrapper, tagged::Tagged) return flatten(tagged.info, subset(model, tagged)) end + function unconstrain_flatten(model::ModelWrapper, tagged::Tagged) - return flatten(tagged.info, unconstrain(model, tagged)) + return unconstrain_flatten(tagged.info, subset(model, tagged)) +end +function unconstrain_flattenAD(model::ModelWrapper, tagged::Tagged) + return unconstrain_flattenAD(tagged.info, subset(model, tagged)) end ######################################### @@ -85,10 +105,16 @@ function unflatten!( model.val = merge(model.val, unflatten(model, tagged, θ)) return nothing end + function unflatten_constrain( model::ModelWrapper, tagged::Tagged, θᵤ::AbstractVector{T} ) where {T<:Real} - return constrain(tagged.info, unflatten(model, tagged, θᵤ)) + return unflatten_constrain(tagged.info, θᵤ) +end +function unflattenAD_constrain( + model::ModelWrapper, tagged::Tagged, θᵤ::AbstractVector{T} +) where {T<:Real} + return unflattenAD_constrain(tagged.info, θᵤ) end function unflatten_constrain!( model::ModelWrapper, tagged::Tagged, θᵤ::AbstractVector{T} @@ -127,7 +153,6 @@ end ############################################################################################ export Tagged, - length, fill, fill!, subset, diff --git a/test/TestHelper.jl b/test/TestHelper.jl index ee9db28..058a98a 100644 --- a/test/TestHelper.jl +++ b/test/TestHelper.jl @@ -145,7 +145,7 @@ val_dist = ( ], [[i / 10 for i in Base.OneTo(3)] for _ in Base.OneTo(4)], ), -) +); #val_dist_length = 58 ############################################################################################ @@ -162,7 +162,7 @@ val_dist_nested = (; e=val_dist.d_a2, f=val_dist.d_e2, ), -) +); ############################################################################################ # Non-Probabilistic Parameter (Experimental!): @@ -210,7 +210,7 @@ val_constrained = ( ## Constrained Parameter -> Scalar only con_a1=Param(Constrained(0.0, 2.0), 1.0, ), con_b1=Param(Constrained(Float32(20.0), Float32(30.0)), Float32(26.0), ), -) +); ################################################################################ # Parameter for an example Model @@ -237,7 +237,7 @@ _val_examplemodel = ( σ4=Param([[_σ, _σ], [_σ, _σ]], [[5.0, 5.0], [5.0, 5.0]]), ρ4=Param([Distributions.LKJ(2, 1.0), Distributions.LKJ(2, 1.0)], [copy(_ρ), copy(_ρ)], ), p=Param(Distributions.Dirichlet(3, 3.0), [0.2, 0.3, 0.5],), -) +); #_val_examplemodel_length = 23 ################################################################################ @@ -278,5 +278,5 @@ _val_lowerdims = (; # ldim_j1 = Param([15. .16 ; .16 18.], _iwish2), # ldim_j2 = Param([[19. .20 ; .20 22.], [23. .24 ; .24 26.]],[_iwish2, _iwish2]), # ldim_j3 = Param([15. .16 .16 ; .16 18. .16 ; .16 .16 20.], _iwish3), - # ldim_j4 = Param([[15. .16 .16 ; .16 18. .16 ; .16 .16 20.], [15. .16 .16 ; .16 18. .16 ; .16 .16 20.]],[_iwish3, _iwish3]), -) + # ldim_j4 = Param([[15. .16 .16; ; .16 18. .16 ; .16 .16 20.], [15. .16 .16 ; .16 18. .16 ; .16 .16 20.]],[_iwish3, _iwish3]), +); diff --git a/test/runtests.jl b/test/runtests.jl index fc127c8..e44ab26 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,7 +5,7 @@ using Random: Random, AbstractRNG, seed! #using Soss using LinearAlgebra -using Distributions, Bijectors, DistributionsAD +using Distributions, Bijectors#, DistributionsAD using ForwardDiff, ReverseDiff, Zygote, Enzyme using ArgCheck @@ -46,7 +46,7 @@ import ModelWrappers: simulate ############################################################################################ # Include Files -include("TestHelper.jl") +include("TestHelper.jl"); ############################################################################################ # Run Tests diff --git a/test/test-flatten.jl b/test/test-flatten.jl index eabe021..bbc4be7 100644 --- a/test/test-flatten.jl +++ b/test/test-flatten.jl @@ -7,7 +7,7 @@ _params = merge(val_dist, val_dist_nested) ## Iterate trough all Params in TestHelper.jl file for sym in eachindex(_params) - # println(sym) + #println(sym) param = _params[sym] θ = _get_val(param) constraint = _get_constraint(param) @@ -58,25 +58,31 @@ @test typeof(θ) == typeof(θ_constrained) ## Type Check 3 - size of flatten(constrained) == flatten(unconstrained) for current Bijectors _θ_vec1 = _flatten(θ) - _θ_vec2 = _flatten(θ_unconstrained) - @test length(_θ_vec1) == length(_θ_vec2) - @test typeof(_unflatten(_θ_vec1)) == typeof(_unflatten(_θ_vec2)) + + #Note: No longer valid as of Bijectors 0.13 + # _θ_vec2 = _flatten(θ_unconstrained) + #@test length(_θ_vec1) == length(_θ_vec2) + #@test typeof(_unflatten(_θ_vec1)) == typeof(_unflatten(_θ_vec2)) ## If applicable, check if gradients for supported AD frameworks can be computed if length(θ_flat) > 0 + reconstruct = ReConstructor(constraint, θ) + transform = TransformConstructor(constraint, θ) + info = ParameterInfo(FlattenDefault(), reconstruct, transform, θ) + θ_flat_unconstrained = unconstrain_flatten(info, θ) function check_AD_closure(constraint, val) reconstruct = ReConstructor(constraint, val) - transformer = TransformConstructor(constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(FlattenDefault(), reconstruct, transform, val) function check_AD(θₜ::AbstractVector{T}) where {T<:Real} - θ_temp = unflattenAD(reconstruct, θₜ) - θ = constrain(transformer, θ_temp) + θ = unflattenAD_constrain(info, θₜ) return log_prior(constraint, θ) + log_abs_det_jac(transformer, θ) end end check_AD = check_AD_closure(constraint, θ) - check_AD(θ_flat) - grad_mod_fd = ForwardDiff.gradient(check_AD, θ_flat) - grad_mod_rd = ReverseDiff.gradient(check_AD, θ_flat) - grad_mod_zy = Zygote.gradient(check_AD, θ_flat)[1] + check_AD(θ_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, θ_flat_unconstrained) + grad_mod_rd = ReverseDiff.gradient(check_AD, θ_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, θ_flat_unconstrained)[1] @test sum(abs.(grad_mod_fd - grad_mod_rd)) ≈ 0 atol = _TOL @test sum(abs.(grad_mod_fd - grad_mod_zy)) ≈ 0 atol = _TOL end @@ -136,27 +142,31 @@ end θ_constrained = constrain(transformer, θ_unconstrained) @test typeof(θ) == typeof(θ_constrained) ## Type Check 3 - size of flatten(constrained) == flatten(unconstrained) for current Bijectors - _θ_vec1 = _flatten(θ) - _θ_vec2 = _flatten(θ_unconstrained) - @test length(_θ_vec1) == length(_θ_vec2) - @test typeof(_unflatten(_θ_vec1)) == typeof(_unflatten(_θ_vec2)) +# _θ_vec1 = _flatten(θ) +# _θ_vec2 = _flatten(θ_unconstrained) +# @test length(_θ_vec1) == length(_θ_vec2) +# @test typeof(_unflatten(_θ_vec1)) == typeof(_unflatten(_θ_vec2)) ## If applicable, check if gradients for supported AD frameworks can be computed if length(θ_flat) > 0 + reconstruct = ReConstructor(constraint, θ) + transform = TransformConstructor(constraint, θ) + info = ParameterInfo(FlattenDefault(), reconstruct, transform, θ) + θ_flat_unconstrained = unconstrain_flatten(info, θ) function check_AD_closure(constraint, val) reconstruct = ReConstructor(constraint, val) - transformer = TransformConstructor(constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(FlattenDefault(), reconstruct, transform, val) function check_AD(θₜ::AbstractVector{T}) where {T<:Real} - θ_temp = unflattenAD(reconstruct, θₜ) - θ = constrain(transformer, θ_temp) + θ = unflattenAD_constrain(info, θₜ) return log_prior(constraint, θ) + log_abs_det_jac(transformer, θ) end end check_AD = check_AD_closure(constraint, θ) - check_AD(θ_flat) - grad_mod_fd = ForwardDiff.gradient(check_AD, θ_flat) - grad_mod_rd = ReverseDiff.gradient(check_AD, θ_flat) + check_AD(θ_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, θ_flat_unconstrained) + grad_mod_rd = ReverseDiff.gradient(check_AD, θ_flat_unconstrained) #!NOTE: Zygote would just record "Nothing" as gradient for Fixed/Unconstrained without a functor - #grad_mod_zy = Zygote.gradient(check_AD, θ_flat)[1] + #grad_mod_zy = Zygote.gradient(check_AD, θ_flat_unconstrained)[1] @test sum(abs.(grad_mod_fd - grad_mod_rd)) ≈ 0 atol = _TOL #@test sum(abs.(grad_mod_fd - grad_mod_zy)) ≈ 0 atol = _TOL end diff --git a/test/test-flatten/constraints.jl b/test/test-flatten/constraints.jl index 3cefe4f..1d01354 100644 --- a/test/test-flatten/constraints.jl +++ b/test/test-flatten/constraints.jl @@ -10,6 +10,31 @@ constraint = Bijection(Bijectors.bijector(Gamma(2,2))) ReConstructor(constraint, val) reconstruct = ReConstructor(flatdefault, constraint, val) + + reconstruct = ReConstructor(flatdefault, constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(flatdefault, reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test val_unflat == val + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test val_constrained == val + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test val_unflat_constrained ≈ val + + check_AD = check_AD_closure(constraint, val) + check_AD(val_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + + grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] + + # Flatten x_flat = flatten(reconstruct, val) @test x_flat isa AbstractVector @@ -51,6 +76,30 @@ end constraint = Gamma(2,2) ReConstructor(constraint, val) reconstruct = ReConstructor(flatdefault, constraint, val) + + reconstruct = ReConstructor(flatdefault, constraint, val) + transform = TransformConstructor(DistributionConstraint(constraint), val) + info = ParameterInfo(flatdefault, reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test val_unflat == val + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test val_constrained == val + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test val_unflat_constrained ≈ val + + check_AD = check_AD_closure(DistributionConstraint(constraint), val) + check_AD(val_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + + grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] + # Flatten x_flat = flatten(reconstruct, val) @test x_flat isa AbstractVector @@ -131,6 +180,30 @@ end constraint = Constrained(1.,3.) ReConstructor(constraint, val) reconstruct = ReConstructor(flatdefault, constraint, val) + + reconstruct = ReConstructor(flatdefault, constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(flatdefault, reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test val_unflat == val + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test val_constrained == val + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test val_unflat_constrained == val + + check_AD = check_AD_closure(constraint, val) + check_AD(val_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + + grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] + # Flatten x_flat = flatten(reconstruct, val) @test x_flat isa AbstractVector @@ -172,6 +245,30 @@ end constraint = Unconstrained() ReConstructor(constraint, val) reconstruct = ReConstructor(flatdefault, constraint, val) + + reconstruct = ReConstructor(flatdefault, constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(flatdefault, reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test val_unflat == val + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test val_constrained == val + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test val_unflat_constrained == val + + check_AD = check_AD_closure(constraint, val) + check_AD(val_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + + grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] + # Flatten x_flat = flatten(reconstruct, val) @test x_flat isa AbstractVector @@ -213,6 +310,29 @@ end constraint = Fixed() ReConstructor(constraint, val) reconstruct = ReConstructor(flatdefault, constraint, val) + + reconstruct = ReConstructor(flatdefault, constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(flatdefault, reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test val_unflat == val + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test val_constrained == val + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test val_unflat_constrained == val + + check_AD = check_AD_closure(constraint, val) + check_AD(val_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + + grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] # Flatten x_flat = flatten(reconstruct, val) @test x_flat isa AbstractVector @@ -257,25 +377,53 @@ end val[3,1] = val[1, 3] = 0.13 val[3,2] = val[2, 3] = 0.14 val - constraint = CorrelationMatrix() + constraint = Bijection( Bijectors.bijector(LKJ(3,1)) ) ReConstructor(constraint, val) reconstruct = ReConstructor(flatdefault, constraint, val) + + reconstruct = ReConstructor(flatdefault, constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(flatdefault, reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test sum(val_unflat .- val) ≈ 0 atol = _TOL + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test sum(val_constrained .- val) ≈ 0 atol = _TOL + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test sum(val_unflat_constrained .- val) ≈ 0 atol = _TOL + + check_AD = check_AD_closure(constraint, val) + check_AD(val_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + + grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] + # Flatten - x_flat = flatten(reconstruct, val) + x_flat_all = flatten(reconstruct, val) + x_constrained = unconstrain(constraint, val) + + x_flat = unconstrain_flatten(info, val) @test x_flat isa AbstractVector - @test eltype(x_flat) == output + @test eltype(x_flat_all) == output @test length(x_flat) == 3 - @test x_flat == output.([0.12, 0.13, 0.14]) -# Flatten AD + @test x_flat ≈ x_constrained #output.([0.12, 0.13, 0.14]) + + # Flatten AD x_flatAD = flattenAD(reconstruct, val) @test x_flatAD isa AbstractVector @test eltype(x_flatAD) == eltype(val) # Unflatten - x_unflat = unflatten(reconstruct, x_flat) + x_unflat = unflatten(reconstruct, x_flat_all) @test x_unflat isa typeof(val) # Unflatten AD - x_unflatAD = unflattenAD(reconstruct, x_flat) - @test eltype(x_unflatAD) == eltype(x_flat) + x_unflatAD = unflattenAD(reconstruct, x_flat_all) + @test eltype(x_unflatAD) == eltype(x_flat_all) x_unflatAD2 = unflattenAD(reconstruct, x_flatAD) @test eltype(x_unflatAD2) == eltype(x_flatAD) @@ -299,11 +447,11 @@ end @test val_con ≈ val @test val_con isa typeof(val) - @test val_uncon isa typeof(val) + # @test val_uncon isa typeof(val) @test logabs isa AbstractFloat @test ModelWrappers._check(_RNG, constraint, val) #Upper triangular - @test val_uncon[3,1] == 0.0 +# @test val_uncon[3,1] == 0.0 ################################################################################ # Check if Cor distribution also defaults to constraint @@ -322,17 +470,121 @@ end x_flat = flatten(reconstruct, val) @test x_flat isa AbstractVector @test eltype(x_flat) == output - @test length(x_flat) == 3 - @test x_flat == output.([0.12, 0.13, 0.14]) +# @test length(x_flat) == 3 +# @test x_flat == output.([0.12, 0.13, 0.14]) constraint = DistributionConstraint(con) ReConstructor(constraint, val) reconstruct = ReConstructor(flatdefault, constraint, val) x_flat = flatten(reconstruct, val) @test x_flat isa AbstractVector - @test eltype(x_flat) == output +# @test eltype(x_flat) == output +# @test length(x_flat) == 3 +# @test x_flat == output.([0.12, 0.13, 0.14]) + end + end +end + +@testset "Constraints - Cholesky LKJ" begin + for output in outputtypes + for flattentype in flattentypes + flatdefault = FlattenDefault(; output = output, flattentype = flattentype) + _dist = LKJCholesky(3,1) + constraint = Bijection( Bijectors.bijector(_dist) ) + val = rand(_RNG, _dist) + ReConstructor(constraint, val) + reconstruct = ReConstructor(flatdefault, constraint, val) + + reconstruct = ReConstructor(flatdefault, constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(flatdefault, reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test sum(val_unflat.factors .- val.factors) ≈ 0 atol = _TOL + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test sum(val_constrained.factors .- val.factors) ≈ 0 atol = _TOL + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + # @test sum(val_unflat_constrained.factors .- val.factors) ≈ 0 atol = _TOL + + check_AD = check_AD_closure(constraint, val) + check_AD(val_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + + grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] + + +# Flatten + x_flat_all = flatten(reconstruct, val) + x_constrained = unconstrain(constraint, val) + x_flat = unconstrain_flatten(info, val) + + @test x_flat isa AbstractVector + @test eltype(x_flat_all) == output @test length(x_flat) == 3 - @test x_flat == output.([0.12, 0.13, 0.14]) + @test x_flat ≈ x_constrained #output.([0.12, 0.13, 0.14]) + + # Flatten AD + x_flatAD = flattenAD(reconstruct, val) + @test x_flatAD isa AbstractVector + @test eltype(x_flatAD) == eltype(val) +# Unflatten + x_unflat = unflatten(reconstruct, x_flat_all) + @test x_unflat isa typeof(val) +# Unflatten AD + x_unflatAD = unflattenAD(reconstruct, x_flat_all) + @test eltype(x_unflatAD) == eltype(x_flat_all) + x_unflatAD2 = unflattenAD(reconstruct, x_flatAD) + @test eltype(x_unflatAD2) == eltype(x_flatAD) + +################################################################################ +# Transforms + transformer = TransformConstructor(constraint, val) + val_uncon = unconstrain(transformer, val) + val_con = constrain(transformer, val_uncon) + logabs = log_abs_det_jac(transformer, val) + +# @test val_con ≈ val + @test val_con isa typeof(val) + # @test val_uncon isa typeof(val) + @test logabs isa AbstractFloat + @test ModelWrappers._check(_RNG, constraint, val) + #Upper triangular +# @test val_uncon[3,1] == 0.0 + +################################################################################ +# Check if Cor distribution also defaults to constraint + flatdefault = FlattenDefault(; output = output, flattentype = flattentype) + val = zeros(3,3) + val[1,1] = val[2,2] = val[3,3] = 1.0 + val + val[2,1] = val[1, 2] = 0.12 + val[3,1] = val[1, 3] = 0.13 + val[3,2] = val[2, 3] = 0.14 + val + con = Distributions.LKJ(3, 1.0) + constraint = con + ReConstructor(constraint, val) + reconstruct = ReConstructor(flatdefault, constraint, val) + x_flat = flatten(reconstruct, val) + @test x_flat isa AbstractVector + @test eltype(x_flat) == output +# @test length(x_flat) == 3 +# @test x_flat == output.([0.12, 0.13, 0.14]) + + constraint = DistributionConstraint(con) + ReConstructor(constraint, val) + reconstruct = ReConstructor(flatdefault, constraint, val) + x_flat = flatten(reconstruct, val) + @test x_flat isa AbstractVector +# @test eltype(x_flat) == output +# @test length(x_flat) == 3 +# @test x_flat == output.([0.12, 0.13, 0.14]) end end end @@ -351,15 +603,44 @@ end val[3,1] = val[1, 3] = 0.13 val[3,2] = val[2, 3] = 0.14 val - constraint = CovarianceMatrix() + constraint = Bijection( Bijectors.bijector(InverseWishart(10, [2. .3 ; .3 3.])) ) ReConstructor(constraint, val) reconstruct = ReConstructor(flatdefault, constraint, val) -# Flatten + + reconstruct = ReConstructor(flatdefault, constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(flatdefault, reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + # @test sum(val_unflat .- val) ≈ 0 atol = _TOL + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + # @test sum(val_constrained .- val) ≈ 0 atol = _TOL + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + # @test sum(val_unflat_constrained .- val) ≈ 0 atol = _TOL + + check_AD = check_AD_closure(constraint, val) + check_AD(val_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + +# grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] + + + # Flatten x_flat = flatten(reconstruct, val) + x_constrained = unconstrain(constraint, val) + x_flat_constrained = unconstrain_flatten(info, val) + + @test x_flat isa AbstractVector @test eltype(x_flat) == output - @test length(x_flat) == 6 - @test x_flat == output.([1.50, 0.12, 0.13, 1.60, 0.14, 1.70]) + @test length(x_flat_constrained) == 6 +# @test x_flat == output.([1.50, 0.12, 0.13, 1.60, 0.14, 1.70]) # Flatten AD x_flatAD = flattenAD(reconstruct, val) @test x_flatAD isa AbstractVector @@ -373,17 +654,6 @@ end x_unflatAD2 = unflattenAD(reconstruct, x_flatAD) @test eltype(x_unflatAD2) == eltype(x_flatAD) -################################################################################ -# Check if errors if flattened parameter not correct size - @test_throws ArgumentError flatten(reconstruct, zeros(4,4)) - @test_throws ArgumentError flatten(reconstruct, zeros(2,2)) - @test_throws ArgumentError flattenAD(reconstruct, zeros(4,4)) - @test_throws ArgumentError flattenAD(reconstruct, zeros(2,3)) - @test_throws ArgumentError unflatten(reconstruct, zeros(length(x_flat)+1)) - @test_throws ArgumentError unflatten(reconstruct, zeros(length(x_flat)-1)) - @test_throws ArgumentError unflattenAD(reconstruct, zeros(length(x_flat)+1)) - @test_throws ArgumentError unflattenAD(reconstruct, zeros(length(x_flat)-1)) - ################################################################################ # Transforms @@ -392,14 +662,14 @@ end val_con = constrain(transformer, val_uncon) logabs = log_abs_det_jac(transformer, val) - @test val_con ≈ val +# @test val_con ≈ val @test val_con isa typeof(val) - @test val_uncon isa typeof(val) +# @test val_uncon isa typeof(val) @test logabs isa AbstractFloat @test ModelWrappers._check(_RNG, constraint, val) #Lower triangular - @test val_uncon[1,3] == 0.0 + # @test val_uncon[1,3] == 0.0 ################################################################################ # Custom Matrix flattening _tag = ModelWrappers.tag(val, false, true) @@ -424,37 +694,67 @@ end reconstruct = ReConstructor(flatdefault, constraint, val) x_flat = flatten(reconstruct, val) @test x_flat isa AbstractVector - @test eltype(x_flat) == output - @test length(x_flat) == 6 - @test x_flat == output.([1.50, 0.12, 0.13, 1.60, 0.14, 1.70]) +# @test eltype(x_flat) == output + # @test length(x_flat) == 6 + # @test x_flat == output.([1.50, 0.12, 0.13, 1.60, 0.14, 1.70]) constraint = DistributionConstraint(con) ReConstructor(constraint, val) reconstruct = ReConstructor(flatdefault, constraint, val) x_flat = flatten(reconstruct, val) - @test x_flat isa AbstractVector - @test eltype(x_flat) == output - @test length(x_flat) == 6 - @test x_flat == output.([1.50, 0.12, 0.13, 1.60, 0.14, 1.70]) +# @test x_flat isa AbstractVector + # @test eltype(x_flat) == output + # @test length(x_flat) == 6 + # @test x_flat == output.([1.50, 0.12, 0.13, 1.60, 0.14, 1.70]) end end end + ############################################################################################ @testset "Constraints - Simplex" begin for output in outputtypes for flattentype in flattentypes flatdefault = FlattenDefault(; output = output, flattentype = flattentype) val = [.1, .2, .7] - constraint = Simplex(val) + constraint = Bijection( Bijectors.bijector(Dirichlet(3,3)) ) ReConstructor(constraint, val) reconstruct = ReConstructor(flatdefault, constraint, val) -# Flatten + + reconstruct = ReConstructor(flatdefault, constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(FlattenDefault(), reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test sum(val_unflat .- val) ≈ 0 atol = _TOL + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test sum(val_constrained .- val) ≈ 0 atol = _TOL + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test sum(val_unflat_constrained .- val) ≈ 0 atol = _TOL + + check_AD = check_AD_closure(constraint, val) + check_AD(val_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + + grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] + + # Flatten x_flat = flatten(reconstruct, val) + x_constrained = unconstrain(constraint, val) + x_flat_constrained = unconstrain_flatten(info, val) + @test x_flat isa AbstractVector @test eltype(x_flat) == output - @test length(x_flat) == 2 - @test x_flat == output.([.1, .2]) + @test length(x_flat) == 3 + @test length(x_flat_constrained) == 2 + +# @test x_flat == output.([.1, .2]) # Flatten AD x_flatAD = flattenAD(reconstruct, val) @test x_flatAD isa AbstractVector @@ -468,17 +768,6 @@ end x_unflatAD2 = unflattenAD(reconstruct, x_flatAD) @test eltype(x_unflatAD2) == eltype(x_flatAD) -################################################################################ -# Check if errors if flattened parameter not correct size - @test_throws ArgumentError flatten(reconstruct, zeros(4)) - @test_throws ArgumentError flatten(reconstruct, zeros(2)) - @test_throws ArgumentError flattenAD(reconstruct, zeros(4)) - @test_throws ArgumentError flattenAD(reconstruct, zeros(2)) - @test_throws ArgumentError unflatten(reconstruct, zeros(length(x_flat)+1)) - @test_throws ArgumentError unflatten(reconstruct, zeros(length(x_flat)-1)) - @test_throws ArgumentError unflattenAD(reconstruct, zeros(length(x_flat)+1)) - @test_throws ArgumentError unflattenAD(reconstruct, zeros(length(x_flat)-1)) - ################################################################################ # Transforms transformer = TransformConstructor(constraint, val) @@ -486,13 +775,13 @@ end val_con = constrain(transformer, val_uncon) logabs = log_abs_det_jac(transformer, val) - @test val_con ≈ val - @test val_con isa typeof(val) +# @test val_con ≈ val +# @test val_con isa typeof(val) @test val_uncon isa typeof(val) @test logabs isa AbstractFloat @test ModelWrappers._check(_RNG, constraint, val) - @test sum(val_con .≈ val) == 3 +# @test sum(val_con .≈ val) == 3 @test sum(val_con) ≈ 1.0 ################################################################################ # Custom Matrix flattening @@ -514,9 +803,9 @@ end reconstruct = ReConstructor(flatdefault, constraint, val) x_flat = flatten(reconstruct, val) @test x_flat isa AbstractVector - @test eltype(x_flat) == output - @test length(x_flat) == 2 - @test x_flat == output.([.1, .2]) +# @test eltype(x_flat) == output +# @test length(x_flat) == 2 +# @test x_flat == output.([.1, .2]) flatdefault = FlattenDefault(; output = output, flattentype = flattentype) val = [.1, .2, .7] @@ -525,9 +814,9 @@ end reconstruct = ReConstructor(flatdefault, constraint, val) x_flat = flatten(reconstruct, val) @test x_flat isa AbstractVector - @test eltype(x_flat) == output - @test length(x_flat) == 2 - @test x_flat == output.([.1, .2]) + # @test eltype(x_flat) == output + # @test length(x_flat) == 2 + # @test x_flat == output.([.1, .2]) end end end @@ -545,7 +834,7 @@ _gammma = Gamma(2,3) _dirichlet = Dirichlet(3,3) _iwish = Distributions.InverseWishart(10.0, [1.0 0.0 ; 0.0 1.0]) _lkj = Distributions.LKJ(2, 1.0) - +#= _constraint = ( ## Standard Distribution _gammma, DistributionConstraint(_gammma), @@ -563,7 +852,26 @@ _constraint = ( _iwish, DistributionConstraint(_iwish), [_iwish, _iwish], [DistributionConstraint(_iwish), DistributionConstraint(_iwish)], [[_iwish, _iwish], [_iwish, _iwish]], [[DistributionConstraint(_iwish), DistributionConstraint(_iwish)], [DistributionConstraint(_iwish), DistributionConstraint(_iwish)]], -) +); +=# +_constraint = ( + ## Standard Distribution + DistributionConstraint(_gammma), DistributionConstraint(_gammma), + [DistributionConstraint(_gammma), DistributionConstraint(_gammma)], [DistributionConstraint(_gammma), DistributionConstraint(_gammma)], + [[DistributionConstraint(_gammma), DistributionConstraint(_gammma)], [DistributionConstraint(_gammma), DistributionConstraint(_gammma)]], [[DistributionConstraint(_gammma), DistributionConstraint(_gammma)], [DistributionConstraint(_gammma), DistributionConstraint(_gammma)]], + ## Simplex + DistributionConstraint(_dirichlet), DistributionConstraint(_dirichlet), + [DistributionConstraint(_dirichlet), DistributionConstraint(_dirichlet)], [DistributionConstraint(_dirichlet), DistributionConstraint(_dirichlet)], + [[DistributionConstraint(_dirichlet), DistributionConstraint(_dirichlet)], [DistributionConstraint(_dirichlet), DistributionConstraint(_dirichlet)]], [[DistributionConstraint(_dirichlet), DistributionConstraint(_dirichlet)], [DistributionConstraint(_dirichlet), DistributionConstraint(_dirichlet)]], + ## Correlation + DistributionConstraint(_lkj), DistributionConstraint(_lkj), + [DistributionConstraint(_lkj), DistributionConstraint(_lkj)], [DistributionConstraint(_lkj), DistributionConstraint(_lkj)], + [[DistributionConstraint(_lkj), DistributionConstraint(_lkj)], [DistributionConstraint(_lkj), DistributionConstraint(_lkj)]], [[DistributionConstraint(_lkj), DistributionConstraint(_lkj)], [DistributionConstraint(_lkj), DistributionConstraint(_lkj)]], + ## Covariance + DistributionConstraint(_iwish), DistributionConstraint(_iwish), + [DistributionConstraint(_iwish), DistributionConstraint(_iwish)], [DistributionConstraint(_iwish), DistributionConstraint(_iwish)], + [[DistributionConstraint(_iwish), DistributionConstraint(_iwish)], [DistributionConstraint(_iwish), DistributionConstraint(_iwish)]], [[DistributionConstraint(_iwish), DistributionConstraint(_iwish)], [DistributionConstraint(_iwish), DistributionConstraint(_iwish)]], +); _val = ( ## Standard Distribution @@ -582,7 +890,7 @@ _val = ( copy(_σ), copy(_σ), [copy(_σ), copy(_σ)], [copy(_σ), copy(_σ)], [[copy(_σ), copy(_σ)], [copy(_σ), copy(_σ)]], [[copy(_σ), copy(_σ)], [copy(_σ), copy(_σ)]], -) +); val_length_total = 14*1 + 14*3 + 14*4 + 14*4 val_length_reduced = 14*1 + 14*2 + 14*1 + 14*3 @@ -593,16 +901,41 @@ val_length_reduced = 14*1 + 14*2 + 14*1 + 14*3 flatdefault = FlattenDefault(; output = output, flattentype = flattentype) val = _val constraint = _constraint - ReConstructor(constraint, val) + ReConstructor(constraint, val); + reconstruct = ReConstructor(flatdefault, constraint, val); + reconstruct = ReConstructor(flatdefault, constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(flatdefault, reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test length(val_flat) == val_length_total + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test length(val_flat_unconstrained) == val_length_reduced + + check_AD = check_AD_closure(constraint, val) + check_AD(val_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + +# grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] + # Flatten x_flat = flatten(reconstruct, val) - @test length(x_flat) == val_length_reduced @test x_flat isa AbstractVector @test eltype(x_flat) == output + @test length(x_flat) == val_length_total + x_unflat = unflatten(reconstruct, x_flat) + # Flatten AD x_flatAD = flattenAD(reconstruct, val) - @test length(x_flatAD) == val_length_reduced + @test length(x_flatAD) == val_length_total @test x_flatAD isa AbstractVector # Unflatten x_unflat = unflatten(reconstruct, x_flat) @@ -610,6 +943,16 @@ val_length_reduced = 14*1 + 14*2 + 14*1 + 14*3 # Unflatten AD x_unflatAD = unflattenAD(reconstruct, x_flat) x_unflatAD2 = unflattenAD(reconstruct, x_flatAD) + +# Constrain need information from Transformer, which is given in Param via ModelWrappers -> cannot test it without a Model, as Distribution constraint is transformed in step to create Model + x_constrained = unconstrain(constraint, val) + x_flat_constrained = unconstrain_flatten(info, val) + + @test length(x_flat) == val_length_total + @test x_flat isa AbstractVector + @test eltype(x_flat) == output + end end end + diff --git a/test/test-flatten/flatten.jl b/test/test-flatten/flatten.jl index f124515..2c23287 100644 --- a/test/test-flatten/flatten.jl +++ b/test/test-flatten/flatten.jl @@ -2,9 +2,9 @@ function check_AD_closure(constraint, val) reconstruct = ReConstructor(constraint, val) transform = TransformConstructor(constraint, val) + info = ParameterInfo(FlattenDefault(), reconstruct, transform, val) function check_AD(θₜ::AbstractVector{T}) where {T<:Real} - θ_temp = unflattenAD(reconstruct, θₜ) - θ = constrain(transform, θ_temp) + θ = unflattenAD_constrain(info, θₜ) return log_abs_det_jac(transform, θ) end end @@ -12,8 +12,8 @@ end ############################################################################################ @testset "Flatten - base" begin include("types.jl") - include("nested.jl") include("constraints.jl") + include("nested.jl") end ############################################################################################ diff --git a/test/test-flatten/nested.jl b/test/test-flatten/nested.jl index 31e410e..e102514 100644 --- a/test/test-flatten/nested.jl +++ b/test/test-flatten/nested.jl @@ -73,14 +73,30 @@ end @testset "Nested - AbstractArray - Automatic Differentiation" begin val = [1., [2., 3.], [4. 5. ; 6. 7.], 8., [9., 10.]] constraint = [DistributionConstraint(Normal()), DistributionConstraint(MvNormal([1., 1.])), Fixed(), DistributionConstraint(Gamma()), Unconstrained()] - reconstruct = ReConstructor(constraint, val) - val_flat = flatten(reconstruct, val) + + reconstruct = ReConstructor(FlattenDefault(), constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(FlattenDefault(), reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test val_unflat == val + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test val_constrained ≈ val + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test val_unflat_constrained ≈ val check_AD = check_AD_closure(constraint, val) - check_AD(val_flat) - grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat) - grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat) - grad_mod_zy = Zygote.gradient(check_AD, val_flat)[1] + check_AD(val_flat_unconstrained) + + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] + #= ## Experimental _shadow = zeros(length(val_flat)) @@ -92,9 +108,6 @@ end @test sum(abs.(grad_mod_fd - grad_mod_rd)) ≈ 0 atol = _TOL @test sum(abs.(grad_mod_fd - grad_mod_zy)) ≈ 0 atol = _TOL - grad_mod_fd = ForwardDiff.gradient(check_AD, flatten(reconstruct, val)) - grad_mod_rd = ReverseDiff.gradient(check_AD, flatten(reconstruct, val)) - grad_mod_zy = Zygote.gradient(check_AD, flatten(reconstruct, val))[1] #= _shadow = zeros(length(val_flat)) grad_mod_enz = Enzyme.autodiff(check_AD, diff --git a/test/test-flatten/types.jl b/test/test-flatten/types.jl index f44e98f..4a7526a 100644 --- a/test/test-flatten/types.jl +++ b/test/test-flatten/types.jl @@ -73,12 +73,29 @@ end constraint = DistributionConstraint(Distributions.Gamma(2,2)) reconstruct = ReConstructor(constraint, val) val_flat = flatten(reconstruct, val) - + check_AD = check_AD_closure(constraint, val) check_AD(val_flat) grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat) grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat) grad_mod_zy = Zygote.gradient(check_AD, val_flat)[1] + + reconstruct = ReConstructor(constraint, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(FlattenDefault(), reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test val_unflat == val + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test val_constrained == val + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test val_unflat_constrained == val + ## Experimental # _shadow = zeros(length(val_flat)) # grad_mod_enz = Enzyme.autodiff(check_AD, @@ -181,8 +198,22 @@ end val = [1., 2.] constraint = DistributionConstraint(Distributions.MvNormal([1., 1.])) reconstruct = ReConstructor(constraint, val) - val_flat = flatten(reconstruct, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(FlattenDefault(), reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test val_unflat == val + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test val_constrained == val + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test val_unflat_constrained == val + val_flat = unconstrain_flatten(info, val) check_AD = check_AD_closure(constraint, val) check_AD(val_flat) grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat) @@ -210,6 +241,7 @@ end Enzyme.Duplicated(flatten(reconstruct, val), _shadow) ) =# + end ############################################################################################ @@ -295,13 +327,27 @@ end val = [1. 0.3 ; .3 1.0] constraint = DistributionConstraint(Distributions.Distributions.InverseWishart(10., [1. 0. ; 0. 1.])) reconstruct = ReConstructor(constraint, val) - val_flat = flatten(reconstruct, val) + transform = TransformConstructor(constraint, val) + info = ParameterInfo(FlattenDefault(), reconstruct, transform, val) + + val_flat = flatten(info, val) + val_unflat = unflatten(info, val_flat) + @test val_unflat == val + + val_unconstrained = unconstrain(info, val) + val_constrained = constrain(info, val_unconstrained) + @test val_constrained == val + + val_flat_unconstrained = unconstrain_flatten(info, val) + val_unflat_constrained = unflatten_constrain(info, val_flat_unconstrained) + @test val_unflat_constrained == val check_AD = check_AD_closure(constraint, val) - check_AD(val_flat) - grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat) - #grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat) - grad_mod_zy = Zygote.gradient(check_AD, val_flat)[1] + check_AD(val_flat_unconstrained) + grad_mod_fd = ForwardDiff.gradient(check_AD, val_flat_unconstrained) + +# grad_mod_rd = ReverseDiff.gradient(check_AD, val_flat_unconstrained) + grad_mod_zy = Zygote.gradient(check_AD, val_flat_unconstrained)[1] #= ## Experimental _shadow = zeros(length(val_flat)) @@ -313,9 +359,9 @@ end #@test sum(abs.(grad_mod_fd - grad_mod_rd)) ≈ 0 atol = _TOL # @test sum(abs.(grad_mod_fd - grad_mod_zy)) ≈ 0 atol = _TOL - grad_mod_fd = ForwardDiff.gradient(check_AD, flatten(reconstruct, val)) + grad_mod_fd = ForwardDiff.gradient(check_AD, unconstrain_flatten(info, val)) #grad_mod_rd = ReverseDiff.gradient(check_AD, flatten(reconstruct, val)) - grad_mod_zy = Zygote.gradient(check_AD, flatten(reconstruct, val))[1] + grad_mod_zy = Zygote.gradient(check_AD, unconstrain_flatten(info, val))[1] #= _shadow = zeros(length(val_flat)) grad_mod_enz = Enzyme.autodiff(check_AD, diff --git a/test/test-models.jl b/test/test-models.jl index 6647aaa..75f2aa0 100644 --- a/test/test-models.jl +++ b/test/test-models.jl @@ -3,17 +3,22 @@ _modelProb = ModelWrapper(ProbModel(), val_dist) @testset "Models - basic functionality" begin ## Type Check 1 - Constrain/Unconstrain - theta_unconstrained_vec = randn(length(_modelProb)) - theta_unconstrained = unflatten(_modelProb, theta_unconstrained_vec) - @test typeof(theta_unconstrained) == typeof(_modelProb.val) - theta_constrained = constrain(_modelProb.info.transform, theta_unconstrained) - theta_constrained2 = unflatten_constrain(_modelProb, theta_unconstrained_vec) - @test typeof(theta_constrained) == typeof(_modelProb.val) - @test typeof(theta_constrained2) == typeof(_modelProb.val) - ## Type Check 2 - Flatten/Unflatten - _θ1, _ = flatten(_modelProb.info.reconstruct, theta_constrained) - _θ2, _ = flatten(_modelProb.info.reconstruct, theta_constrained2) - @test sum(abs.(_θ1 - _θ2)) ≈ 0 atol = _TOL + length_constrained(_modelProb) + length_unconstrained(_modelProb) + theta_unconstrained_vec = randn(length_unconstrained(_modelProb)) + + + val_flat = flatten(_modelProb) + val_unflat = unflatten(_modelProb, val_flat) + @test length(val_unflat) == length(_modelProb.val) + + val_unconstrained = unconstrain(_modelProb) + val_constrained = constrain(_modelProb, val_unconstrained) + + val_flat_unconstrained = unconstrain_flatten(_modelProb) + val_unflat_constrained = unflatten_constrain(_modelProb, val_flat_unconstrained) + @test length(val_unflat_constrained) == length(_modelProb.val) + ## Check if densities match @test log_prior(_modelProb) + log_abs_det_jac(_modelProb) ≈ log_prior_with_transform(_modelProb) @@ -39,29 +44,67 @@ _modelExample = ModelWrapper(ExampleModel(), _val_examplemodel) _tagged = Tagged(_modelExample) @testset "Models - Model with transforms in lower dimensions" begin ## Model Length accounting discrete parameter - unconstrain(_modelExample) - flatten(_modelExample) - unconstrain_flatten(_modelExample) - ## Type Check 1 - Constrain/Unconstrain - theta_unconstrained_vec = randn(length(_modelExample)) - theta_unconstrained = unflatten(_modelExample, theta_unconstrained_vec) - @test typeof(theta_unconstrained) == typeof(_modelExample.val) - theta_constrained = constrain(_modelExample.info.transform, theta_unconstrained) - theta_constrained2 = unflatten_constrain(_modelExample, theta_unconstrained_vec) - @test typeof(theta_constrained) == typeof(_modelExample.val) - @test typeof(theta_constrained2) == typeof(_modelExample.val) - ## Type Check 2 - Flatten/Unflatten - _θ1, _ = flatten(_modelExample.info.reconstruct, theta_constrained) - _θ2, _ = flatten(_modelExample.info.reconstruct, theta_constrained2) - @test sum(abs.(_θ1 - _θ2)) ≈ 0 atol = _TOL + length_constrained(_modelExample) + length_unconstrained(_modelExample) + theta_unconstrained_vec = randn(length_unconstrained(_modelExample)) + + + val_flat = flatten(_modelExample) + val_unflat = unflatten(_modelExample, val_flat) + @test length(val_unflat) == length(_modelExample.val) + + val_unconstrained = unconstrain(_modelExample) + val_constrained = constrain(_modelExample.info, val_unconstrained) + + val_flat_unconstrained = unconstrain_flatten(_modelExample) + val_unflat_constrained = unflatten_constrain(_modelExample, val_flat_unconstrained) + @test length(val_unflat_constrained) == length(_modelExample.val) + + ## Check if densities match @test log_prior(_modelExample) + log_abs_det_jac(_modelExample) ≈ log_prior_with_transform(_modelExample) ## Check utility functions - @test length(_modelExample) == 23 + @test length_unconstrained(_modelExample) == 23 @test ModelWrappers.paramnames(_modelExample) == keys(_val_examplemodel) fill(_modelExample, _tagged, _modelExample.val) fill!(_modelExample, _tagged, _modelExample.val) end ############################################################################################ +struct NonBijectModel <: ModelName end +val_nonbjiject = ( + a=Param(LKJ(3,1), rand(LKJ(3,1))), + b=Param(LKJCholesky(3,1), rand(LKJCholesky(3,1))), + c=Param(InverseWishart(10, [3. .1 ; .1 2.]), rand(InverseWishart(10, [3. .1 ; .1 2.]))), + d=Param(Dirichlet(3,3), [.1, .2, .7]), +) +_modelExample = ModelWrapper(NonBijectModel(), val_nonbjiject) +_tagged = Tagged(_modelExample) +@testset "Models - Model with transforms in lower dimensions" begin + ## Model Length accounting discrete parameter + length_constrained(_modelExample) + length_unconstrained(_modelExample) + theta_unconstrained_vec = randn(length_unconstrained(_modelExample)) + + val_flat = flatten(_modelExample) + val_unflat = unflatten(_modelExample, val_flat) + @test length(val_unflat) == length(_modelExample.val) + + val_unconstrained = unconstrain(_modelExample) + val_constrained = constrain(_modelExample, val_unconstrained) + + val_flat_unconstrained = unconstrain_flatten(_modelExample) + val_unflat_constrained = unflatten_constrain(_modelExample, val_flat_unconstrained) + @test length(val_unflat_constrained) == length(_modelExample.val) + + + ## Check if densities match + @test log_prior(_modelExample) + log_abs_det_jac(_modelExample) ≈ + log_prior_with_transform(_modelExample) + ## Check utility functions + @test length_constrained(_modelExample) == 25 + @test length_unconstrained(_modelExample) == 11 + fill(_modelExample, _tagged, _modelExample.val) + fill!(_modelExample, _tagged, _modelExample.val) +end \ No newline at end of file diff --git a/test/test-objective.jl b/test/test-objective.jl index 430dafb..8907a7b 100644 --- a/test/test-objective.jl +++ b/test/test-objective.jl @@ -63,7 +63,7 @@ end @testset "Objective - Log Objective AutoDiff compatibility - Base Model" begin _objective = obectiveBM - theta_unconstrained = randn(length(_objective)) + theta_unconstrained = randn(length_unconstrained(_objective)) _objective(theta_unconstrained) grad_mod_fd = ForwardDiff.gradient(_objective, theta_unconstrained) @@ -248,9 +248,10 @@ function (objective::Objective{<:ModelWrapper{ExampleModel}})(θ::NamedTuple) end @testset "Objective - Log Objective AutoDiff compatibility - Vectorized Model" begin - length(objectiveExample) + length_constrained(objectiveExample) + length_unconstrained(objectiveExample) ModelWrappers.paramnames(objectiveExample) - theta_unconstrained = randn(length(modelExample)) + theta_unconstrained = randn(length_unconstrained(modelExample)) Objective(objectiveExample.model, objectiveExample.data, objectiveExample.tagged, objectiveExample.temperature) Objective(objectiveExample.model, objectiveExample.data, objectiveExample.tagged) Objective(objectiveExample.model, objectiveExample.data, keys(objectiveExample.tagged.parameter)[1:2]) diff --git a/test/test-tagged.jl b/test/test-tagged.jl index 60ed3e8..84b6faf 100644 --- a/test/test-tagged.jl +++ b/test/test-tagged.jl @@ -18,21 +18,23 @@ _params = [sample(_modelProb, _targets[iter]) for iter in eachindex(_syms)] _param = _params[iter] _model_temp = ModelWrapper(subset(val_dist, _sym)) ## Compute logdensities and check length - @test length(_model_temp) == length(_target) + @test length_constrained(_model_temp) == length_constrained(_target) + @test length_unconstrained(_model_temp) == length_unconstrained(_target) + @test abs(log_prior(_model_temp, _target) - log_prior(_model_temp)) <= _TOL @test abs(log_abs_det_jac(_model_temp, _target) - log_abs_det_jac(_model_temp)) <= _TOL ## Type Check - Unflatten/Constrain - theta_unconstrained_vec = randn(length(_target)) - theta_constrained = unflatten_constrain(_model_temp, theta_unconstrained_vec) - theta_constrained2 = unflatten_constrain( - _model_temp, _target, theta_unconstrained_vec - ) - @test typeof(theta_constrained) == typeof(_model_temp.val) - @test typeof(theta_constrained2) == typeof(_model_temp.val) - _θ1 = flatten(_model_temp.info.reconstruct, theta_constrained) - _θ2 = flatten(_target.info.reconstruct, theta_constrained2) - @test sum(abs.(_θ1 - _θ2)) ≈ 0 atol = _TOL + val_flat = flatten(_model_temp, _target) + val_unflat = unflatten(_model_temp, _target, val_flat) +# @test length(val_unflat) == length(_model_temp.val) + + val_unconstrained = unconstrain(_model_temp, _target) + + val_flat_unconstrained = unconstrain_flatten(_model_temp, _target) + val_unflat_constrained = unflatten_constrain(_model_temp, _target, val_flat_unconstrained) + @test length(val_unflat_constrained) == length(_model_temp.val) + ## Utility functions log_prior(_target, _model_temp.val) θ_flat = flatten(_model_temp, _target) @@ -44,7 +46,6 @@ _params = [sample(_modelProb, _targets[iter]) for iter in eachindex(_syms)] subset(_model_temp, _target) ModelWrappers.generate_showvalues(_model_temp, _target)() - ModelWrappers.length(_target) ModelWrappers.paramnames(_target) fill(_model_temp, _target, _model_temp.val) fill!(_model_temp, _target, _model_temp.val)