From 41620db1f938fc862331a659872b7ca3b25de943 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Thu, 14 Nov 2024 07:37:06 +0000 Subject: [PATCH 1/2] Improve `show` function of BUGSModel (#236) --- Project.toml | 2 +- src/JuliaBUGS.jl | 1 - src/model.jl | 54 ++++++++++++++++++++++++++++++++++-------------- 3 files changed, 40 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index 01f959bea..982689e6c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "JuliaBUGS" uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" -version = "0.6.4" +version = "0.6.5" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl index db888c352..d2add3de3 100644 --- a/src/JuliaBUGS.jl +++ b/src/JuliaBUGS.jl @@ -16,7 +16,6 @@ using StaticArrays import Base: ==, hash, Symbol, size import Distributions: truncated -import AbstractPPL: AbstractContext, evaluate!! export @bugs export compile, initialize! diff --git a/src/model.jl b/src/model.jl index 4c178d785..06f39eded 100644 --- a/src/model.jl +++ b/src/model.jl @@ -77,23 +77,47 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple, end function Base.show(io::IO, model::BUGSModel) - if model.transformed - println( - io, - "BUGSModel (transformed, with dimension $(model.transformed_param_length)):", - "\n", - ) + # Print model type and dimension + space_type = + model.transformed ? "transformed (unconstrained)" : "original (constrained)" + dim = if model.transformed + model.transformed_param_length else - println( - io, - "BUGSModel (untransformed, with dimension $(model.untransformed_param_length)):", - "\n", - ) + model.untransformed_param_length + end + printstyled(io, "BUGSModel"; bold=true, color=:blue) + println(io, " (parameters are in ", space_type, " space, with dimension ", dim, "):\n") + + # Group and print parameters + printstyled(io, " Model parameters:\n"; bold=true, color=:yellow) + grouped_params = Dict{Symbol,Vector{VarName}}() + for param in model.parameters + sym = AbstractPPL.getsym(param) + push!(get!(grouped_params, sym, VarName[]), param) + end + for (sym, params) in grouped_params + param_str = length(params) == 1 ? string(params[1]) : "$(join(params, ", "))" + print(io, " ") + printstyled(io, param_str; color=:cyan) + println(io) + end + println(io) + + # Print variable info + printstyled(io, " Variable sizes and types:\n"; bold=true, color=:yellow) + for (name, value) in pairs(model.evaluation_env) + type_str = if isa(value, Number) + "type = $(typeof(value))" + else + "size = $(size(value)), type = $(typeof(value))" + end + print(io, " ") + printstyled(io, name; color=:cyan) + print(io, ": ") + printstyled(io, type_str; color=:green) + println(io) end - println(io, " Model parameters:") - println(io, " ", join(model.parameters, ", "), "\n") - println(io, " Variable values:") - return println(io, "$(model.evaluation_env)") + return nothing end """ From 26ef66241c4f8748a95dc990922046ad812a4bea Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Sat, 16 Nov 2024 12:16:36 +0000 Subject: [PATCH 2/2] Permit dot call (like `Distributions.Normal`) to be used in model definition (#237) The motivation for the PR is initially to allow the use [`product_distribution`](https://juliastats.org/Distributions.jl/stable/multivariate/#Distributions.product_distribution) in JuliaBUGS model. By writing ```julia @bugs begin x[1:2] ~ Distributions.product_distribution(fill(Normal(), 2)) end ``` It also enables ```julia julia> foo(x) = x + 1 foo (generic function with 1 method) julia> model_def = @bugs begin a = Main.foo(b) end quote a = Main.foo(b) end julia> compile(model_def, (;b=2)) BUGSModel (parameters are in transformed (unconstrained) space, with dimension 0): Model parameters: Variable sizes and types: a: type = Int64 b: type = Int64 ``` i.e., an easier way to introduce external functions into JuliaBUGS. This can't fully replace `@register_primitive` yet, because using `Main` module is relying on Julia runtime behavior and not very intuitive. Examples: ```julia julia> module TestModule bar(x) = x + 1 end Main.TestModule julia> model_def = @bugs begin a = TestModule.bar(b) end quote a = TestModule.bar(b) end julia> compile(model_def, (; b = 1)) ERROR: UndefVarError: `TestModule` not defined in `JuliaBUGS` ... ``` one would need to do ```julia julia> model_def = @bugs begin a = Main.TestModule.bar(b) end quote a = Main.TestModule.bar(b) end ``` also ```julia julia> @testset "t" begin foo1(x) = x + 1 model_def = @bugs begin a = Main.foo1(b) end compile(model_def, (; b= 1)) end t: Error During Test at REPL[18]:1 Got exception outside of a @test UndefVarError: `foo1` not defined in `Main` ... ``` --- Project.toml | 2 +- src/parser/bugs_macro.jl | 2 ++ test/compile.jl | 8 ++++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 982689e6c..f9783e7da 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "JuliaBUGS" uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" -version = "0.6.5" +version = "0.7.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/parser/bugs_macro.jl b/src/parser/bugs_macro.jl index b54287d00..d50982df5 100644 --- a/src/parser/bugs_macro.jl +++ b/src/parser/bugs_macro.jl @@ -151,6 +151,8 @@ function bugs_expression(expr, line_num) error( "Keyword argument syntax is not supported in BUGS, error at $line_num: $(expr)" ) + elseif Meta.isexpr(expr, :.) + return expr else error("Invalid expression at $line_num: `$expr`") end diff --git a/test/compile.jl b/test/compile.jl index 067df5f70..a2af7942a 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -87,3 +87,11 @@ end @test AbstractPPL.get(model_init_1.evaluation_env, @varname(beta)) == 1 end end + +@testset "dot call" begin + model_def = @bugs begin + x[1:2] ~ Distributions.product_distribution(fill(Distributions.Normal(0, 1), 2)) + end + model = compile(model_def, (;)) + @test model.evaluation_env.x isa Vector{Float64} +end