Skip to content

Commit

Permalink
Pass keywords to glmnet (#19)
Browse files Browse the repository at this point in the history
* actually pass keywords to glmnet

* test passing of keywords to glmnet
  • Loading branch information
tiemvanderdeure authored Aug 25, 2024
1 parent bf76f47 commit 93a2c31
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/lasso.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
function fit_lasso_path(
mm, presences;
wts, penalty_factor, λ, kw...)
weights, penalty_factor, lambda, kw...)

presence_matrix = [1 .- presences presences]
GLMNet.glmnet(
mm, presence_matrix, GLMNet.Binomial();
weights = wts, penalty_factor = penalty_factor, lambda = λ, standardize = false)
weights, penalty_factor, lambda, standardize = false, kw...)
end

get_coefs(path::GLMNet.GLMNetPath) = path.betas
5 changes: 2 additions & 3 deletions src/maxnet_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ function maxnet(
kw...
)
end
#maxnet(presences, predictors; kw...) = maxnet(presences, predictors, features; kw...)

### internal methods where features is not a keyword

Expand Down Expand Up @@ -129,10 +128,10 @@ function _maxnet(
weights = presences .* 1. .+ (1 .- presences) .* weight_factor

# generate lambdas
λ = lambdas(reg, presences, weights; λmax = 4, n = 200)
lambda = lambdas(reg, presences, weights; λmax = 4, n = 200)

# Fit the model
lassopath = fit_lasso_path(mm, presences, wts = weights, penalty_factor = reg, λ = λ)
lassopath = fit_lasso_path(mm, presences; weights, penalty_factor = reg, lambda, kw...)

# get the coefficients out
coefs = SparseArrays.sparse(get_coefs(lassopath)[:, end])
Expand Down
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ env1 = map(e -> [e[1]], env) # just the first row
@test predictors == (a = [1,2,3,1], b = [1,2,3,1])
end


@testset "Maxnet" begin
# some class combinations and keywords
m = maxnet(p_a, env; features = "lq");
Expand Down Expand Up @@ -72,7 +73,14 @@ end
@test complexity(empty_model) == 0
@test Maxnet.selected_features(empty_model) == Symbol[]
@test length(unique(predict(empty_model, env))) == 1

# test that keywords arguments are passed to glmnet
weights = ifelse.(p_a, 1.0, 10.0)
m_w = maxnet(p_a, env; features = "lq", addsamplestobackground = false, weights)
m = maxnet(p_a, env; features = "lq", addsamplestobackground = false)
@test m_w.entropy > m.entropy
end
m = maxnet(p_a, env; features = "lq", addsamplestobackground = false)

@testset "MLJ" begin
using MLJBase
Expand Down

0 comments on commit 93a2c31

Please sign in to comment.