Nested transformation with Shift does not work for Matrix output #191

Closed
mgmverburg opened this issue Aug 1, 2021 · 1 comment

@mgmverburg

What I want to achieve is, for example, a logit-transformed outcome in which covariates can have different effects, so that the mean is shifted separately for each entry of the matrix.
I already had a way to do this, but it involved an arraydist with a for-loop, and while optimizing I noticed that this approach caused a slowdown (among other things, I believe it introduced type instabilities, judging by code_warntype).
I therefore wanted to find a way to do this in one shot. Most things work, but this specific use case does not, even though in theory it should.

To make it slightly weirder, the Shift bijector does work with 2D data, as in model_1 in the code below, which throws no error. However, wrapping it in another layer such as a logit transform, as in model_2, throws the error listed below the code.

using Bijectors, Turing, LinearAlgebra, Random

M, N = 8, 20
output = rand(LogitNormal(0, 1), M, N)

@model function test_1(output, M, N)
    mvn = MvNormal(zeros(M), LinearAlgebra.I)
    z ~ filldist(Normal(0, 1), M, N)
    output ~ transformed(filldist(mvn, N), Bijectors.Shift(z))
end

model_1 = test_1(output, M, N)
chain_1 = sample(model_1, NUTS(0.65), 10)


@model function test_2(output, M, N)
    mvn = MvNormal(zeros(M), LinearAlgebra.I)
    b = inv(Bijectors.Logit{2}(0.0, 1.0))
    z ~ filldist(Normal(0, 1), M, N)
    output ~ transformed(transformed(filldist(mvn, N), Bijectors.Shift(z)), b)
end

model_2 = test_2(output, M, N)
chain_2 = sample(model_2, NUTS(0.65), 10)
Error message:

ERROR: MethodError: no method matching _logabsdetjac_shift(::Array{Float64,2}, ::Array{Float64,2}, ::Val{2})
Closest candidates are:
  _logabsdetjac_shift(::T1, ::AbstractArray{T2,2}, ::Val{2}) where {T1<:Union{Real, AbstractArray{T,1} where T}, T2<:Real} at /root/.julia/packages/Bijectors/LmARY/src/bijectors/shift.jl:36
  _logabsdetjac_shift(::Union{Tracker.TrackedArray{var"#s25",1,A} where A where var"#s25"<:Real, Tracker.TrackedReal}, ::AbstractArray{var"#s24",2} where var"#s24"<:Real, ::Val{1}) at /root/.julia/packages/Bijectors/LmARY/src/compat/tracker.jl:80
  _logabsdetjac_shift(::T1, ::AbstractArray{T2,2}, ::Val{1}) where {T1<:Union{Real, AbstractArray{T,1} where T}, T2<:Real} at /root/.julia/packages/Bijectors/LmARY/src/bijectors/shift.jl:35
  ...
Stacktrace:
 [1] logabsdetjac(::Bijectors.Shift{Array{Float64,2},2}, ::Array{Float64,2}) at /root/.julia/packages/Bijectors/LmARY/src/bijectors/shift.jl:30
 [2] logpdf_with_trans(::Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate}, ::Array{Float64,2}, ::Bool) at /root/.julia/packages/Bijectors/LmARY/src/Bijectors.jl:132
 [3] _logpdf(::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}) at /root/.julia/packages/Bijectors/LmARY/src/transformed_distribution.jl:124
 [4] 
logpdf(::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}) at /root/.julia/packages/Distributions/Xrm9e/src/matrixvariates.jl:164 [5] loglikelihood(::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}) at /root/.julia/packages/Distributions/Xrm9e/src/matrixvariates.jl:227 [6] observe(::DynamicPPL.SampleFromUniform, ::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}) at /root/.julia/packages/DynamicPPL/wCsuo/src/context_implementations.jl:148 [7] _tilde(::DynamicPPL.SampleFromUniform, 
::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}) at /root/.julia/packages/DynamicPPL/wCsuo/src/context_implementations.jl:112 [8] tilde(::DynamicPPL.DefaultContext, ::DynamicPPL.SampleFromUniform, ::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, ::Array{Float64,2}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}) at /root/.julia/packages/DynamicPPL/wCsuo/src/context_implementations.jl:68 [9] tilde_observe(::DynamicPPL.DefaultContext, ::DynamicPPL.SampleFromUniform, ::Bijectors.TransformedDistribution{Bijectors.TransformedDistribution{DistributionsAD.VectorOfMultivariate{Continuous,MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},FillArrays.Fill{MvNormal{Float64,PDMats.ScalMat{Float64},Array{Float64,1}},1,Tuple{Base.OneTo{Int64}}}},Bijectors.Shift{Array{Float64,2},2},Matrixvariate},Inverse{Bijectors.Logit{2,Float64},2},Matrixvariate}, 
::Array{Float64,2}, ::AbstractPPL.VarName{:output,Tuple{}}, ::Tuple{}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}) at /root/.julia/packages/DynamicPPL/wCsuo/src/context_implementations.jl:93 [10] #3 at ./REPL[9]:5 [inlined] [11] (::var"#3#4")(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}, ::DynamicPPL.SampleFromUniform, ::DynamicPPL.DefaultContext, ::Array{Float64,2}, ::Int64, ::Int64) at ./none:0 [12] macro expansion at /root/.julia/packages/DynamicPPL/wCsuo/src/model.jl:0 [inlined] [13] _evaluate(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.ThreadSafeVarInfo{DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64},Array{Base.RefValue{Float64},1}}, ::DynamicPPL.SampleFromUniform, ::DynamicPPL.DefaultContext) at /root/.julia/packages/DynamicPPL/wCsuo/src/model.jl:150 [14] evaluate_threadsafe(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64}, ::DynamicPPL.SampleFromUniform, ::DynamicPPL.DefaultContext) at /root/.julia/packages/DynamicPPL/wCsuo/src/model.jl:140 [15] 
(::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}})(::Random._GLOBAL_RNG, ::DynamicPPL.VarInfo{DynamicPPL.Metadata{Dict{AbstractPPL.VarName,Int64},Array{Distribution,1},Array{AbstractPPL.VarName,1},Array{Real,1},Array{Set{DynamicPPL.Selector},1}},Float64}, ::DynamicPPL.SampleFromUniform, ::DynamicPPL.DefaultContext) at /root/.julia/packages/DynamicPPL/wCsuo/src/model.jl:94 [16] DynamicPPL.VarInfo(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.SampleFromUniform, ::DynamicPPL.DefaultContext) at /root/.julia/packages/DynamicPPL/wCsuo/src/varinfo.jl:132 [17] DynamicPPL.VarInfo(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.SampleFromUniform) at /root/.julia/packages/DynamicPPL/wCsuo/src/varinfo.jl:131 [18] step(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}}; resume_from::Nothing, kwargs::Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:nadapts,),Tuple{Int64}}}) at /root/.julia/packages/DynamicPPL/wCsuo/src/sampler.jl:69 [19] macro expansion at /root/.julia/packages/AbstractMCMC/ByHEr/src/sample.jl:123 [inlined] [20] macro expansion at /root/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined] [21] (::AbstractMCMC.var"#21#22"{Bool,String,Nothing,Int64,Int64,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:nadapts,),Tuple{Int64}}},Random._GLOBAL_RNG,DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}},DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}},Int64,Int64})() at /root/.julia/packages/AbstractMCMC/ByHEr/src/logging.jl:11 [22] with_logstate(::Function, 
::Any) at ./logging.jl:408 [23] with_logger(::Function, ::LoggingExtras.TeeLogger{Tuple{LoggingExtras.EarlyFilteredLogger{TerminalLoggers.TerminalLogger,AbstractMCMC.var"#1#3"{Module}},LoggingExtras.EarlyFilteredLogger{Logging.ConsoleLogger,AbstractMCMC.var"#2#4"{Module}}}}) at ./logging.jl:514 [24] with_progresslogger(::Function, ::Module, ::Logging.ConsoleLogger) at /root/.julia/packages/AbstractMCMC/ByHEr/src/logging.jl:34 [25] macro expansion at /root/.julia/packages/AbstractMCMC/ByHEr/src/logging.jl:10 [inlined] [26] mcmcsample(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}}, ::Int64; progress::Bool, progressname::String, callback::Nothing, discard_initial::Int64, thinning::Int64, chain_type::Type{T} where T, kwargs::Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:nadapts,),Tuple{Int64}}}) at /root/.julia/packages/AbstractMCMC/ByHEr/src/sample.jl:114 [27] sample(::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"#3#4",(:output, :M, :N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}}, ::Int64; chain_type::Type{T} where T, resume_from::Nothing, progress::Bool, nadapts::Int64, discard_adapt::Bool, discard_initial::Int64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /root/.julia/packages/Turing/TbEZL/src/inference/hmc.jl:133 [28] sample at /root/.julia/packages/Turing/TbEZL/src/inference/hmc.jl:116 [inlined] [29] #sample#2 at /root/.julia/packages/Turing/TbEZL/src/inference/Inference.jl:142 [inlined] [30] sample at /root/.julia/packages/Turing/TbEZL/src/inference/Inference.jl:142 [inlined] [31] #sample#1 at /root/.julia/packages/Turing/TbEZL/src/inference/Inference.jl:132 [inlined] [32] sample(::DynamicPPL.Model{var"#3#4",(:output, :M, 
:N),(),(),Tuple{Array{Float64,2},Int64,Int64},Tuple{}}, ::NUTS{Turing.Core.ForwardDiffAD{40},(),AdvancedHMC.DiagEuclideanMetric}, ::Int64) at /root/.julia/packages/Turing/TbEZL/src/inference/Inference.jl:132 [33] top-level scope at REPL[11]:1

I was able to fix this (at least for the specific case in which I hit the error) by simply adding:
Bijectors._logabsdetjac_shift(a::T1, x::AbstractMatrix{T2}, ::Val{2}) where {T1<:Union{Real, AbstractMatrix}, T2<:Real} = zero(T2)

However, I am not sure whether that is the cleanest fix for the package as a whole, or whether it again covers just one use case.
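
For context, here is a short sketch of why a constant zero is a mathematically sound return value here. The notation ($X$, $A$, $Y$, $Z$, $p_X$) is introduced only for illustration and does not come from the package. A shift by a matrix $A$ acts on the vectorized input as a translation, so its Jacobian is the identity:

$$
b(X) = X + A, \qquad X, A \in \mathbb{R}^{M \times N}
$$

$$
\frac{\partial\, \mathrm{vec}\big(b(X)\big)}{\partial\, \mathrm{vec}(X)} = I_{MN}
\quad\Longrightarrow\quad
\log\left|\det \frac{\partial\, \mathrm{vec}\big(b(X)\big)}{\partial\, \mathrm{vec}(X)}\right| = \log 1 = 0
$$

Assuming that `inv(Bijectors.Logit{2}(0.0, 1.0))` is the elementwise logistic map onto $(0, 1)$, the nested model corresponds to $Y = \mathrm{logistic}(X + Z)$, and the change of variables picks up a Jacobian term only from the logit layer:

$$
\log p_Y(Y) = \log p_X\big(\mathrm{logit}(Y) - Z\big) - \sum_{i=1}^{M} \sum_{j=1}^{N} \log\big(Y_{ij}\,(1 - Y_{ij})\big)
$$

Here $p_X$ stands for the density of `filldist(mvn, N)`. The shift contributes nothing to the sum, which is consistent with the added method returning `zero(T2)`.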

Bijectors version 0.9.7, Turing 0.16.0

@torfjelde
Member

Ah yes, this is a missing definition. But your solution is correct 👍 Once #183 has gone through, these things shouldn't happen.
