Skip to content

Commit

Permalink
add ability to set multi-valued booster params (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
ExpandingMan authored Feb 3, 2023
1 parent 22a0239 commit a0616f1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "XGBoost"
uuid = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
version = "2.2.2"
version = "2.2.3"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
9 changes: 9 additions & 0 deletions src/booster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ end
setparam!(b::Booster, name::AbstractString, val) = setparam!(b, name, string(val))
setparam!(b::Booster, name::Symbol, val) = setparam!(b, string(name), val)

setmultiparams!(b::Booster, name::Union{Symbol,AbstractString}, vals) = foreach(v -> setparam!(b, name, v), vals)

# the API for some parameters involves multiple separate calls to XGBoosterSetParam
# multi methods for resolving ambiguities
setparam!(b::Booster, name::Symbol, vals::AbstractVector) = setmultiparams!(b, name, vals)
setparam!(b::Booster, name::AbstractString, vals::AbstractVector) = setmultiparams!(b, name, vals)
setparam!(b::Booster, name::Symbol, vals::Tuple) = setmultiparams!(b, name, vals)
setparam!(b::Booster, name::AbstractString, vals::Tuple) = setmultiparams!(b, name, vals)

"""
setparams!(b::Booster; kw...)
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ end
watchlist=watchlist,
η=1, max_depth=2,
objective="binary:logistic",
# check that we can set multiple param values
eval_metric=["rmse", "rmsle"],
)
end

Expand Down Expand Up @@ -171,6 +173,7 @@ end
η=1.0, max_depth=2,
objective="binary:logistic",
watchlist=Dict(),
eval_metric=("mae", "mape"),
)
preds = predict(bst, dtest)
XGBoost.save(bst, model_file)
Expand Down

0 comments on commit a0616f1

Please sign in to comment.