diff --git a/Project.toml b/Project.toml index 59afd2b6..c53a2ce6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.13.12" +version = "0.13.13" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -43,7 +43,7 @@ ArgCheck = "1, 2" ChainRules = "1" ChainRulesCore = "0.10.11, 1" ChangesOfVariables = "0.1" -Compat = "3, 4" +Compat = "3.46, 4.2" Distributions = "0.25.33" ForwardDiff = "0.10" DistributionsAD = "0.6" diff --git a/ext/BijectorsReverseDiffExt.jl b/ext/BijectorsReverseDiffExt.jl index a733bd71..4489cb26 100644 --- a/ext/BijectorsReverseDiffExt.jl +++ b/ext/BijectorsReverseDiffExt.jl @@ -268,7 +268,6 @@ end @grad_from_chainrules _link_chol_lkj(x::TrackedMatrix) @grad_from_chainrules _link_chol_lkj_from_upper(x::TrackedMatrix) @grad_from_chainrules _link_chol_lkj_from_lower(x::TrackedMatrix) -@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector) cholesky_lower(X::TrackedMatrix) = track(cholesky_lower, X) @grad function cholesky_lower(X_tracked::TrackedMatrix) diff --git a/ext/BijectorsTrackerExt.jl b/ext/BijectorsTrackerExt.jl index b44cf3a3..9d04ea32 100644 --- a/ext/BijectorsTrackerExt.jl +++ b/ext/BijectorsTrackerExt.jl @@ -338,106 +338,16 @@ Bijectors.lower_triangular(A::TrackedMatrix) = track(Bijectors.lower_triangular, end Bijectors._inv_link_chol_lkj(y::TrackedVector) = track(Bijectors._inv_link_chol_lkj, y) -@grad function Bijectors._inv_link_chol_lkj(y_tracked::TrackedVector) - y = data(y_tracked) - K = _triu1_dim_from_length(length(y)) - - W = similar(y, K, K) - - z_vec = similar(y) - tmp_vec = similar(y) - - idx = 1 - @inbounds for j in 1:K - W[1, j] = 1 - for i in 2:j - z = tanh(y[idx]) - tmp = W[i - 1, j] - - z_vec[idx] = z - tmp_vec[idx] = tmp - idx += 1 - - W[i - 1, j] = z * tmp - W[i, j] = tmp * sqrt(1 - z^2) - end - for i in (j + 1):K - W[i, j] = 0 - end - end - - function pullback_inv_link_chol_lkj(ΔW) - LinearAlgebra.checksquare(ΔW) - - Δy = zero(y) - - @inbounds for j in 1:K - idx_up_to_prev_column = ((j - 1) * (j - 2) ÷ 2) - Δtmp = ΔW[j, j] - for i in j:-1:2 - idx = idx_up_to_prev_column + i - 1 - Δz = - ΔW[i - 1, j] * tmp_vec[idx] - - Δtmp * tmp_vec[idx] / sqrt(1 - z_vec[idx]^2) * z_vec[idx] - Δy[idx] = Δz / cosh(y[idx])^2 - Δtmp = ΔW[i - 1, j] * z_vec[idx] + Δtmp * sqrt(1 - z_vec[idx]^2) - end - end - - return (Δy,) - end - - return W, pullback_inv_link_chol_lkj -end - Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_lkj, y) -@grad function Bijectors._inv_link_chol_lkj(y_tracked::TrackedMatrix) +@grad function Bijectors._inv_link_chol_lkj(y_tracked::Union{TrackedVector,TrackedMatrix}) y = data(y_tracked) + W_logJ, back = Bijectors._inv_link_chol_lkj_rrule(y) - K = LinearAlgebra.checksquare(y) - - w = similar(y) - - z_mat = similar(y) # cache for adjoint - tmp_mat = similar(y) - - @inbounds for j in 1:K - w[1, j] = 1 - for i in 2:j - z = tanh(y[i - 1, j]) - tmp = w[i - 1, j] - - z_mat[i, j] = z - tmp_mat[i, j] = tmp - - w[i - 1, j] = z * tmp - w[i, j] = tmp * sqrt(1 - z^2) - end - for i in (j + 1):K - w[i, j] = 0 - end - end - - function pullback_inv_link_chol_lkj(Δw) - LinearAlgebra.checksquare(Δw) - - Δy = zero(y) - - @inbounds for j in 1:K - Δtmp = Δw[j, j] - for i in j:-1:2 - Δz = - Δw[i - 1, j] * tmp_mat[i, j] - - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j] - Δy[i - 1, j] = Δz / cosh(y[i - 1, j])^2 - Δtmp = Δw[i - 1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2) - end - end - - return (Δy,) + function pullback_inv_link_chol_lkj(ΔW_ΔlogJ) + return (back(ΔW_ΔlogJ),) end - return w, pullback_inv_link_chol_lkj + return W_logJ, pullback_inv_link_chol_lkj end Bijectors._link_chol_lkj(w::TrackedMatrix) = track(Bijectors._link_chol_lkj, w) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 0ef63a2d..50a8f07e 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -82,6 +82,10 @@ if VERSION < v"1.1" using Compat: eachcol end +if VERSION < v"1.9" + using Compat: stack +end + const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0"))) _debug(str) = @debug str diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 1e68fce8..a7fe0b92 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -71,9 +71,13 @@ function transform(b::CorrBijector, X::AbstractMatrix{<:Real}) return r end -function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) - w = _inv_link_chol_lkj(y) - return pd_from_upper(w) +function with_logabsdet_jacobian(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) + U, logJ = _inv_link_chol_lkj(y) + K = size(U, 1) + for j in 2:(K - 1) + logJ += (K - j) * log(U[j, j]) + end + return pd_from_upper(U), logJ end logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_inv_corr(Y) @@ -131,8 +135,15 @@ function logabsdetjac(b::VecCorrBijector, x) return -logabsdetjac(inverse(b), b(x)) end -function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) - return pd_from_upper(_inv_link_chol_lkj(y)) +function with_logabsdet_jacobian(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + U_logJ = _inv_link_chol_lkj(y) + # workaround for `Tracker.TrackedTuple` not supporting iteration + U, logJ = U_logJ[1], U_logJ[2] + K = size(U, 1) + for j in 2:(K - 1) + logJ += (K - j) * log(U[j, j]) + end + return pd_from_upper(U), logJ end function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) @@ -225,15 +236,16 @@ function logabsdetjac(b::VecCholeskyBijector, x) return -logabsdetjac(inverse(b), b(x)) end -function transform(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) +function with_logabsdet_jacobian(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) + factors, logJ = _inv_link_chol_lkj(y) if b.orig.mode === :U # This Cholesky constructor is compatible with Julia v1.6 # for later versions Cholesky(::UpperTriangular) works - return Cholesky(_inv_link_chol_lkj(y), 'U', 0) + return Cholesky(factors, 'U', 0), logJ else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor. # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. # If we don't, the return-type can be both `Matrix` and `Transposed`. - return Cholesky(transpose_eager(_inv_link_chol_lkj(y)), 'L', 0) + return Cholesky(transpose_eager(factors), 'L', 0), logJ end end @@ -281,48 +293,44 @@ which is the above implementation. function _link_chol_lkj(W::AbstractMatrix) K = LinearAlgebra.checksquare(W) - z = similar(W) # z is also UpperTriangular. + y = similar(W) # z is also UpperTriangular. # Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero. - # This block can't be integrated with loop below, because W[1,1] != 0. - @inbounds z[:, 1] .= 0 - - @inbounds for j in 2:K - z[1, j] = atanh(W[1, j]) - tmp = sqrt(1 - W[1, j]^2) - for i in 2:(j - 1) - p = W[i, j] / tmp - tmp *= sqrt(1 - p^2) - z[i, j] = atanh(p) + @inbounds for j in 1:K + remainder_sq = one(eltype(W)) + for i in 1:(j - 1) + z = W[i, j] / sqrt(remainder_sq) + y[i, j] = atanh(z) + remainder_sq -= W[i, j]^2 end for i in j:K - z[i, j] = 0 + y[i, j] = 0 end end - return z + return y end function _link_chol_lkj_from_upper(W::AbstractMatrix) K = LinearAlgebra.checksquare(W) N = ((K - 1) * K) ÷ 2 # {K \choose 2} free parameters - z = similar(W, N) + y = similar(W, N) idx = 1 @inbounds for j in 2:K - z[idx] = atanh(W[1, j]) + y[idx] = atanh(W[1, j]) idx += 1 - tmp = sqrt(1 - W[1, j]^2) + remainder_sq = 1 - W[1, j]^2 for i in 2:(j - 1) - p = W[i, j] / tmp - tmp *= sqrt(1 - p^2) - z[idx] = atanh(p) + z = W[i, j] / sqrt(remainder_sq) + y[idx] = atanh(z) + remainder_sq -= W[i, j]^2 idx += 1 end end - return z + return y end _link_chol_lkj_from_lower(W::AbstractMatrix) = _link_chol_lkj_from_upper(transpose_eager(W)) @@ -333,47 +341,120 @@ _link_chol_lkj_from_lower(W::AbstractMatrix) = _link_chol_lkj_from_upper(transpo Inverse link function for cholesky factor. """ function _inv_link_chol_lkj(Y::AbstractMatrix) + LinearAlgebra.require_one_based_indexing(Y) K = LinearAlgebra.checksquare(Y) W = similar(Y) + T = float(eltype(W)) + logJ = zero(T) + idx = 1 @inbounds for j in 1:K - W[1, j] = 1 - for i in 2:j - z = tanh(Y[i - 1, j]) - tmp = W[i - 1, j] - W[i - 1, j] = z * tmp - W[i, j] = tmp * sqrt(1 - z^2) + log_remainder = zero(T) # log of proportion of unit vector remaining + for i in 1:(j - 1) + z = tanh(Y[i, j]) + W[i, j] = z * exp(log_remainder) + log_remainder += log1p(-z^2) / 2 + logJ += log_remainder end + logJ += log_remainder + W[j, j] = exp(log_remainder) for i in (j + 1):K W[i, j] = 0 end end - return W + return W, logJ end function _inv_link_chol_lkj(y::AbstractVector) + LinearAlgebra.require_one_based_indexing(y) K = _triu1_dim_from_length(length(y)) W = similar(y, K, K) + T = float(eltype(W)) + logJ = zero(T) idx = 1 @inbounds for j in 1:K - W[1, j] = 1 - for i in 2:j + log_remainder = zero(T) # log of proportion of unit vector remaining + for i in 1:(j - 1) z = tanh(y[idx]) idx += 1 - tmp = W[i - 1, j] - W[i - 1, j] = z * tmp - W[i, j] = tmp * sqrt(1 - z^2) + W[i, j] = z * exp(log_remainder) + log_remainder += log1p(-z^2) / 2 + logJ += log_remainder end + logJ += log_remainder + W[j, j] = exp(log_remainder) for i in (j + 1):K W[i, j] = 0 end end - return W + return W, logJ +end + +# shared reverse-mode AD rule code +function _inv_link_chol_lkj_rrule(y::AbstractVector) + LinearAlgebra.require_one_based_indexing(y) + K = _triu1_dim_from_length(length(y)) + + W = similar(y, K, K) + T = typeof(log(one(eltype(W)))) + logJ = zero(T) + + z_vec = tanh.(y) + + idx = 1 + W[1, 1] = 1 + @inbounds for j in 2:K + log_remainder = zero(T) # log of proportion of unit vector remaining + for i in 1:(j - 1) + z = z_vec[idx] + idx += 1 + W[i, j] = z * exp(log_remainder) + log_remainder += log1p(-z^2) / 2 + logJ += log_remainder + end + logJ += log_remainder + W[j, j] = exp(log_remainder) + for i in (j + 1):K + W[i, j] = 0 + end + end + + function pullback_inv_link_chol_lkj((ΔW, ΔlogJ)) + LinearAlgebra.require_one_based_indexing(ΔW) + Δy = similar(y) + + idx_local = lastindex(y) + @inbounds for j in K:-1:2 + Δlog_remainder = W[j, j] * ΔW[j, j] + 2ΔlogJ + for i in (j - 1):-1:1 + W_ΔW = W[i, j] * ΔW[i, j] + z = z_vec[idx_local] + Δy[idx_local] = (inv(z) - z) * W_ΔW - z * Δlog_remainder + idx_local -= 1 + Δlog_remainder += ΔlogJ + W_ΔW + end + end + + return Δy + end + + return (W, logJ), pullback_inv_link_chol_lkj +end + +function _inv_link_chol_lkj_rrule(y::AbstractMatrix) + K = LinearAlgebra.checksquare(y) + y_vec = Bijectors._triu_to_vec(y, 1) + W_logJ, back = _inv_link_chol_lkj_reverse(y_vec) + function pullback_inv_link_chol_lkj(ΔW_ΔlogJ) + return update_triu_from_vec(_triu_to_vec(back(ΔW_ΔlogJ), 1), 1, K) + end + + return W_logJ, pullback_inv_link_chol_lkj end function _logabsdetjac_inv_corr(Y::AbstractMatrix) diff --git a/src/chainrules.jl b/src/chainrules.jl index 3f598634..f15e1c22 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -267,55 +267,14 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_lower), W::AbstractMa end function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) - K = _triu1_dim_from_length(length(y)) - - W = similar(y, K, K) - - z_vec = similar(y) - tmp_vec = similar(y) - - idx = 1 - @inbounds for j in 1:K - W[1, j] = 1 - for i in 2:j - z = tanh(y[idx]) - tmp = W[i - 1, j] - - z_vec[idx] = z - tmp_vec[idx] = tmp - idx += 1 - - W[i - 1, j] = z * tmp - W[i, j] = tmp * sqrt(1 - z^2) - end - for i in (j + 1):K - W[i, j] = 0 - end - end - - function pullback_inv_link_chol_lkj(ΔW_thunked) - ΔW = ChainRulesCore.unthunk(ΔW_thunked) - - Δy = zero(y) - - @inbounds for j in 1:K - idx_up_to_prev_column = ((j - 1) * (j - 2) ÷ 2) - Δtmp = ΔW[j, j] - for i in j:-1:2 - idx = idx_up_to_prev_column + i - 1 - tmp = tmp_vec[idx] - z = z_vec[idx] - - Δz = ΔW[i - 1, j] * tmp - Δtmp * tmp / sqrt(1 - z^2) * z - Δy[idx] = Δz / cosh(y[idx])^2 - Δtmp = ΔW[i - 1, j] * z + Δtmp * sqrt(1 - z^2) - end - end + W_logJ, back = _inv_link_chol_lkj_rrule(y) + function pullback_inv_link_chol_lkj(ΔW_ΔlogJ) + Δy = back(ChainRulesCore.unthunk(ΔW_ΔlogJ)) return ChainRulesCore.NoTangent(), Δy end - return W, pullback_inv_link_chol_lkj + return W_logJ, pullback_inv_link_chol_lkj end function ChainRulesCore.rrule(::typeof(pd_from_upper), X::AbstractMatrix) diff --git a/src/utils.jl b/src/utils.jl index 82c15de6..9fd6c65c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -11,8 +11,14 @@ _vec(x::Real) = x lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) -pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' -pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) +function pd_from_lower(X) + L = lower_triangular(X) + return L * L' +end +function pd_from_upper(X) + U = upper_triangular(X) + return U' * U +end # HACK: Allows us to define custom chain rules while we wait for upstream fixes. transpose_eager(X::AbstractMatrix) = permutedims(X) diff --git a/test/Project.toml b/test/Project.toml index a7a089c3..6f156b8b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" @@ -21,6 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ChainRulesTestUtils = "0.7, 1" ChangesOfVariables = "0.1" Combinatorics = "1.0.2" +Compat = "3.46, 4.2" DistributionsAD = "0.6.3" FillArrays = "1" FiniteDifferences = "0.11, 0.12" diff --git a/test/bijectors/product_bijector.jl b/test/bijectors/product_bijector.jl index 818f89c0..78310572 100644 --- a/test/bijectors/product_bijector.jl +++ b/test/bijectors/product_bijector.jl @@ -33,14 +33,27 @@ has_square_jacobian(b, x) = Bijectors.output_size(b, x) == size(x) end y, logjac = stack(map(first, results)), sum(last, results) - test_bijector( - b_prod, - x; - y, - logjac, - changes_of_variables_test=has_square_jacobian(b, xs[1]), - test_not_identity=!isidentity, - ) + if VERSION < v"1.9" && length(size(d)) > 0 + # `eachslice`, which is used by `ProductBijector`, is type-unstable + # for multivariate cases on Julia < 1.9. Hence the type-inference fails. + @test_broken test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) + else + test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) + end end @testset "Two-dim stack: $(nameof(typeof(d)))" for (d, isidentity) in ds @@ -57,13 +70,27 @@ has_square_jacobian(b, x) = Bijectors.output_size(b, x) == size(x) results = map(Base.Fix1(with_logabsdet_jacobian, b), xs) y, logjac = stack(map(first, results)), sum(last, results) - test_bijector( - b_prod, - x; - y, - logjac, - changes_of_variables_test=has_square_jacobian(b, xs[1]), - test_not_identity=!isidentity, - ) + if VERSION < v"1.9" && length(size(d)) > 0 + # `eachslice`, which is used by `ProductBijector`, does not support + # `dims` with more than one value. As a result, stacking anything that + # isn't univariate won't work here. + @test_broken test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) + else + test_bijector( + b_prod, + x; + y, + logjac, + changes_of_variables_test=has_square_jacobian(b, xs[1]), + test_not_identity=!isidentity, + ) + end end end diff --git a/test/runtests.jl b/test/runtests.jl index 01b81fb8..2ce01010 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,6 +29,10 @@ using ChangesOfVariables: ChangesOfVariables using InverseFunctions: InverseFunctions using LazyArrays: LazyArrays +if VERSION < v"1.9" + using Compat: stack +end + const GROUP = get(ENV, "GROUP", "All") # Always include this since it can be useful for other tests.