Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving to Bijectors #125

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ version = "0.9.3"
[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicIterators = "6c76993d-992e-5bf1-9e63-34920a5a5a38"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c"
InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
Expand All @@ -22,11 +25,14 @@ NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
NestedTuples = "a734d2a7-8d68-409b-9419-626914d4061d"
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
StrideArrays = "d1fa6d79-ef01-42a6-86c9-f7c551f8593b"
StrideArraysCore = "7792a7ef-975c-4747-a70f-980b88e8d1da"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"

Expand All @@ -51,7 +57,6 @@ SimpleTraits = "0.9"
SpecialFunctions = "0.10, 1"
StaticArrays = "0.12, 1"
StatsFuns = "0.9"
TransformVariables = "0.4"
Tricks = "0.1.4"
Tullio = "0.2, 0.3"
julia = "1.5"
Expand Down
18 changes: 11 additions & 7 deletions src/MeasureTheory.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
module MeasureTheory

using LoopVectorization: ArrayInterface
using InlineTest

using Random

using ConcreteStructs
using MLStyle
using NestedTuples
using TransformVariables
const TV = TransformVariables
# using TransformVariables
# const TV = TransformVariables

import Base
import Distributions
Expand All @@ -22,6 +25,7 @@ using DynamicIterators
using KeywordCalls
using ConstructionBase
using Accessors
using ArrayInterface

const ∞ = InfiniteArrays.∞

Expand All @@ -44,7 +48,7 @@ is understood to be `basemeasure(μ)`.
"""
function logdensity end

include("const.jl")
# include("const.jl")
include("exp.jl")
include("domains.jl")
include("utils.jl")
Expand All @@ -65,7 +69,7 @@ include("combinators/superpose.jl")
include("combinators/product.jl")
include("combinators/for.jl")
include("combinators/power.jl")
include("combinators/transforms.jl")
# include("combinators/transforms.jl")
include("combinators/spikemixture.jl")
include("combinators/chain.jl")
include("kernel.jl")
Expand All @@ -90,11 +94,11 @@ include("parameterized/bernoulli.jl")
include("parameterized/poisson.jl")
include("parameterized/binomial.jl")
include("parameterized/multinomial.jl")
include("parameterized/lkj-cholesky.jl")
# include("parameterized/lkj-cholesky.jl")
include("parameterized/negativebinomial.jl")

include("transforms/corrcholesky.jl")
include("transforms/ordered.jl")
# include("transforms/corrcholesky.jl")
# include("transforms/ordered.jl")

include("density.jl")
# include("pushforward.jl")
Expand Down
16 changes: 8 additions & 8 deletions src/combinators/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ function logdensity(d::ProductMeasure, x)
end


function TV.as(d::ProductMeasure{F,A}) where {F,A<:AbstractArray}
d1 = marginals(d).f(first(marginals(d).data))
as(Array, as(d1), size(marginals(d))...)
end
# function TV.as(d::ProductMeasure{F,A}) where {F,A<:AbstractArray}
# d1 = marginals(d).f(first(marginals(d).data))
# as(Array, as(d1), size(marginals(d))...)
# end

function Base.show(io::IO, ::MIME"text/plain", d::ProductMeasure{F,A}) where {F,A<:AbstractArray}
io = IOContext(io, :compact => true)
Expand All @@ -109,10 +109,10 @@ end
###############################################################################
# I <: Base.Generator

function TV.as(d::ProductMeasure{F,I}) where {F, I<:Base.Generator}
d1 = marginals(d).f(first(marginals(d).iter))
as(Array, as(d1), size(marginals(d))...)
end
# function TV.as(d::ProductMeasure{F,I}) where {F, I<:Base.Generator}
# d1 = marginals(d).f(first(marginals(d).iter))
# as(Array, as(d1), size(marginals(d))...)
# end


export rand!
Expand Down
6 changes: 3 additions & 3 deletions src/combinators/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ basemeasure(μ::Pullback) = Pullback(μ.f, basemeasure(μ.ν), false)

basemeasure(ν::Pushforward) = Pushforward(ν.f, basemeasure(ν.μ), false)

TV.as(ν::Pushforward) = ν.f ∘ as(ν.μ)
# TV.as(ν::Pushforward) = ν.f ∘ as(ν.μ)

TV.as(μ::Pullback) = inverse(μ.f) ∘ μ.ν
# TV.as(μ::Pullback) = inverse(μ.f) ∘ μ.ν

TV.as(::Lebesgue) = asℝ
# TV.as(::Lebesgue) = asℝ


# Base measure of a pushforward of Lebesgue(ℝ) through a scalar transform is again Lebesgue(ℝ).
# NOTE(review): this method still uses the `TV` (TransformVariables) alias, which this
# changeset comments out in MeasureTheory.jl (the `const TV = ...` line) — confirm this
# file remains excluded from `include` until it is ported to Bijectors.
basemeasure(::Pushforward{TV.CallableTransform{T}, Lebesgue{ℝ}}) where {T <: TV.ScalarTransform} = Lebesgue(ℝ)
Expand Down
2 changes: 1 addition & 1 deletion src/combinators/weighted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ logweight(μ::AbstractWeightedMeasure) = μ.logweight
# The base measure of a weighted measure is the underlying measure stored in its `base` field.
function basemeasure(μ::AbstractWeightedMeasure)
    return μ.base
end


TV.as(μ::AbstractWeightedMeasure) = TV.as(μ.base)
# TV.as(μ::AbstractWeightedMeasure) = TV.as(μ.base)

function logdensity(sm::AbstractWeightedMeasure, x)
logdensity(sm.base, x) + sm.logweight
Expand Down
4 changes: 2 additions & 2 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ function _half(__module__, ex)
:($dist($(args...))) => begin
halfdist = Symbol(:Half, dist)

TV = TransformVariables
# TV = TransformVariables

quote

Expand All @@ -276,7 +276,7 @@ function _half(__module__, ex)
return abs(rand(rng, T, unhalf(μ)))
end

$TV.as(::$halfdist) = asℝ₊
# $TV.as(::$halfdist) = asℝ₊

(::$halfdist ≪ ::Lebesgue{ℝ₊}) = true
end
Expand Down
46 changes: 24 additions & 22 deletions src/parameterized.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using TransformVariables
# using TransformVariables

export ParameterizedMeasure
# Abstract supertype for measures carrying named parameters; `N` is presumably the
# tuple of parameter names (e.g. `(:μ, :σ)` for `Normal{(:μ, :σ)}`) — confirm against
# the concrete parameterized types.
abstract type ParameterizedMeasure{N} <: AbstractMeasure end
Expand Down Expand Up @@ -77,29 +77,31 @@ julia> asparams(Normal{(:μ,:σ)})
TransformVariables.TransformTuple{NamedTuple{(:μ, :σ), Tuple{TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}}}}((μ = asℝ, σ = asℝ₊), 2)
```
"""
function asparams end

asparams(μ::ParameterizedMeasure, v::Val) = asparams(constructor(μ), v)
asparams(μ, s::Symbol) = asparams(μ, Val(s))

asparams(M::Type{PM}) where {PM<:ParameterizedMeasure} = asparams(M, NamedTuple())

function asparams(::Type{M}, constraints::NamedTuple{N2}) where {N1, N2, M<: ParameterizedMeasure{N1}}
# @show M
thekeys = params(M, constraints)
t1 = NamedTuple{thekeys}(asparams(M, Val(k)) for k in thekeys)
t2 = NamedTuple{N2}(map(asConst, values(constraints)))
C = constructorof(M)
# @show C
# @show constraints
# @show transforms
# Make sure we end up with a consistent ordering
ordered_transforms = params(C(merge(t1, t2)))
return TV.as(ordered_transforms)
end


asparams(μ::ParameterizedMeasure, nt::NamedTuple=NamedTuple()) = asparams(constructor(μ), nt)
# function asparams end

# asparams(μ::ParameterizedMeasure, v::Val) = asparams(constructor(μ), v)
# asparams(μ, s::Symbol) = asparams(μ, Val(s))

# asparams(M::Type{PM}) where {PM<:ParameterizedMeasure} = asparams(M, NamedTuple())

# function asparams(::Type{M}, constraints::NamedTuple{N2}) where {N1, N2, M<: ParameterizedMeasure{N1}}
# # @show M
# thekeys = params(M, constraints)
# t1 = NamedTuple{thekeys}(asparams(M, Val(k)) for k in thekeys)
# t2 = NamedTuple{N2}(map(asConst, values(constraints)))
# C = constructorof(M)
# # @show C
# # @show constraints
# # @show transforms
# # Make sure we end up with a consistent ordering
# ordered_transforms = params(C(merge(t1, t2)))
# return TV.as(ordered_transforms)
# end


# asparams(μ::ParameterizedMeasure, nt::NamedTuple=NamedTuple()) = asparams(constructor(μ), nt)

export params

Expand Down
2 changes: 1 addition & 1 deletion src/parameterized/beta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export Beta
beta => β
]

TV.as(::Beta) = as𝕀
# TV.as(::Beta) = as𝕀

function logdensity(d::Beta{(:α, :β)}, x)
return (d.α - 1) * log(x) + (d.β - 1) * log(1 - x) - logbeta(d.α, d.β)
Expand Down
2 changes: 1 addition & 1 deletion src/parameterized/cauchy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Base.rand(rng::AbstractRNG, T::Type, μ::Cauchy{()}) = randn(rng, T) / randn(rng

# Every Cauchy measure is ≪ (dominated by) Lebesgue measure on a real line.
function ≪(::Cauchy, ::Lebesgue{X}) where {X<:Real}
    return true
end

TV.as(::Cauchy) = asℝ
# TV.as(::Cauchy) = asℝ

@half Cauchy()

Expand Down
2 changes: 1 addition & 1 deletion src/parameterized/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export Dirichlet

@parameterized Dirichlet(α)

TV.as(d::Dirichlet{(:α,)}) = TV.UnitSimplex(length(d.α))
# TV.as(d::Dirichlet{(:α,)}) = TV.UnitSimplex(length(d.α))

function basemeasure(μ::Dirichlet{(:α,)})
t = as(μ)
Expand Down
2 changes: 1 addition & 1 deletion src/parameterized/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function Base.rand(rng::AbstractRNG, T::Type, d::Exponential{(:λ,)})
randexp(rng, T) / d.λ
end

TV.as(::Exponential) = asℝ₊
# TV.as(::Exponential) = asℝ₊

function logdensity(d::Exponential{(:λ,)}, x)
z = x * d.λ
Expand Down
2 changes: 1 addition & 1 deletion src/parameterized/gumbel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function Base.rand(rng::AbstractRNG, d::Gumbel{()})
-log(-log(u))
end

TV.as(::Gumbel) = asℝ
# TV.as(::Gumbel) = asℝ

# Every Gumbel measure is ≪ (dominated by) Lebesgue measure on a real line.
function ≪(::Gumbel, ::Lebesgue{X}) where {X<:Real}
    return true
end

Expand Down
2 changes: 1 addition & 1 deletion src/parameterized/inverse-gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ Base.rand(rng::AbstractRNG, T::Type, μ::InverseGamma{(:shape,)}) = rand(rng, Di

# Every InverseGamma measure is ≪ (dominated by) Lebesgue measure on a real line.
function ≪(::InverseGamma, ::Lebesgue{X}) where {X<:Real}
    return true
end

TV.as(::InverseGamma) = asℝ₊
# TV.as(::InverseGamma) = asℝ₊

@μσ_methods InverseGamma(shape)
2 changes: 1 addition & 1 deletion src/parameterized/laplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Base.rand(rng::AbstractRNG, μ::Laplace{()}) = rand(rng, Dists.Laplace())

# Every Laplace measure is ≪ (dominated by) Lebesgue measure on a real line.
function ≪(::Laplace, ::Lebesgue{X}) where {X<:Real}
    return true
end

TV.as(::Laplace) = asℝ
# TV.as(::Laplace) = asℝ

@μσ_methods Laplace()
@half Laplace()
Expand Down
2 changes: 1 addition & 1 deletion src/parameterized/lkj-cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ function logdensity(d::LKJCholesky{(:k, :logη)}, L::Union{LinearAlgebra.Abstrac
end


TV.as(d::LKJCholesky) = CorrCholesky(d.k)
# TV.as(d::LKJCholesky) = CorrCholesky(d.k)

function basemeasure(μ::LKJCholesky{(:k,:η)})
t = as(μ)
Expand Down
Loading