diff --git a/src/bayesian/negativebinomial_regression.jl b/src/bayesian/negativebinomial_regression.jl index 20638ab..e2f8a47 100644 --- a/src/bayesian/negativebinomial_regression.jl +++ b/src/bayesian/negativebinomial_regression.jl @@ -16,7 +16,7 @@ end """ ```julia -fit(formula::FormulaTerm, data::DataFrame, modelClass::NegBinomRegression, prior::Prior_Ridge, h::Float64 = 0.1, sim_size::Int64 = 1000) +fit(formula::FormulaTerm, data::DataFrame, modelClass::NegBinomRegression, prior::Prior_Ridge, h::Float64 = 1.0, sim_size::Int64 = 1000) ``` Fit a Bayesian Negative Binomial Regression model on the input data with a Ridge prior. @@ -103,7 +103,7 @@ function fit( data::DataFrame, modelClass::NegBinomRegression, prior::Prior_Ridge, - h::Float64 = 0.1, + h::Float64 = 1.0, sim_size::Int64 = 1000 ) @model NegativeBinomialRegression(X, y) = begin @@ -218,7 +218,7 @@ function fit( data::DataFrame, modelClass::NegBinomRegression, prior::Prior_Laplace, - h::Float64 = 0.1, + h::Float64 = 1.0, sim_size::Int64 = 1000 ) @model NegativeBinomialRegression(X, y) = begin @@ -511,6 +511,7 @@ function fit( data::DataFrame, modelClass::NegBinomRegression, prior::Prior_HorseShoe, + h::Float64 = 1.0, sim_size::Int64 = 1000 ) @model NegativeBinomialRegression(X, y) = begin @@ -523,11 +524,11 @@ function fit( τ ~ halfcauchy ## Global Shrinkage λ ~ filldist(halfcauchy, p) ## Local Shrinkage - σ ~ halfcauchy + σ ~ InverseGamma(h, h) #α ~ Normal(0, τ * σ) β0 = repeat([0], p) ## prior mean - β ~ MvNormal(β0, λ * τ *σ) - + # β ~ MvNormal(β0, λ * τ *σ) + β ~ MvNormal(β0, λ * τ) ## link #z = α .+ X * β diff --git a/test/numerical/bayesian/LogisticRegression.jl b/test/numerical/bayesian/LogisticRegression.jl index 07aa246..470d66b 100644 --- a/test/numerical/bayesian/LogisticRegression.jl +++ b/test/numerical/bayesian/LogisticRegression.jl @@ -40,8 +40,8 @@ tests = [ ( Prior_HorseShoe(), ( - (Logit(), 0.38683395333332327), - (Probit(), 0.38253233489484173), + (Logit(), 0.7599999999740501), + (Probit(), 0.7580564600751047), (Cloglog(), 0.7667553778881738), (Cauchit(), 0.7706755564626601) )