diff --git a/src/Maxnet.jl b/src/Maxnet.jl
index 776648e..f01db6c 100644
--- a/src/Maxnet.jl
+++ b/src/Maxnet.jl
@@ -11,6 +11,7 @@ export IdentityLink, CloglogLink, LogitLink, LogLink # re-export relevant links
 export LassoBackend, GLMNetBackend
 export maxnet, predict
 export LinearFeature, CategoricalFeature, QuadraticFeature, ProductFeature, ThresholdFeature, HingeFeature
+export MaxnetBinaryClassifier
 
 # Write your package code here.
diff --git a/src/feature_classes.jl b/src/feature_classes.jl
index b23a97c..e639b94 100644
--- a/src/feature_classes.jl
+++ b/src/feature_classes.jl
@@ -26,6 +26,17 @@ function features_from_string(s::AbstractString)
 end
 
 # Default features based on number of presences
+
+"""
+    default_features(np)
+
+Takes the number of presences `np` and returns the `Vector` of `AbstractFeatureClass`es that `maxnet` uses by default.
+
+If `np` is less than 10, then only `LinearFeature` and `CategoricalFeature` are used.
+If it is at least 10, then `QuadraticFeature` is additionally used.
+If it is at least 15, then `HingeFeature` is additionally used.
+If it is at least 80, then `ProductFeature` is additionally used.
+"""
 function default_features(np)
     features = [LinearFeature(), CategoricalFeature()]
     if np >= 10
diff --git a/src/maxnet_function.jl b/src/maxnet_function.jl
index 1682c08..e0ddf26 100644
--- a/src/maxnet_function.jl
+++ b/src/maxnet_function.jl
@@ -23,7 +23,8 @@ end
 """
     maxnet(
-        presences, predictors, [features];
+        presences, predictors;
+        [features],
         regularization_multiplier, regularization_function,
         addsamplestobackground, weight_factor,
         backend, kw...
@@ -36,20 +37,28 @@ end
 # Keywords
 - `features`: Either:
     - A `Vector` of `AbstractFeatureClass` type features; or
-    - A string where "l" = linear and categorical, "q" = quadratic, "p" = product, "t" = threshold, "h" = hinge; or
-    - Nothing, in which case the default features based on the number of presences are used
+    - A `String` where "l" = linear and categorical, "q" = quadratic, "p" = product, "t" = threshold, "h" = hinge (e.g. "lqh"); or
+    - The default, in which case the features used are based on the number of presences. See [`default_features`](@ref)
 - `regularization_multiplier`: A constant to adjust regularization, where a higher `regularization_multiplier` results in a higher penalization for features
 - `regularization_function`: A function to compute a regularization for each feature. A default `regularization_function` is built in.
 - `addsamplestobackground`: A boolean, where `true` adds the background samples to the predictors. Defaults to `true`.
 - `n_knots`: the number of knots used for Threshold and Hinge features. Defaults to 50. Ignored if there are neither Threshold nor Hinge features
-- `weight_factor`: A `Float64` to adjust the weight of the background samples. Defaults to 100.0.
-- `backend`: Either `LassoBackend()` or `GLMNetBackend()`, to use Lasso.jl or GLMNet.jl fit the model.
+- `weight_factor`: A `Float64` value to adjust the weight of the background samples. Defaults to 100.0.
+- `backend`: Either `LassoBackend()` or `GLMNetBackend()`, to use either Lasso.jl or GLMNet.jl to fit the model. Lasso.jl is written in pure Julia, but can be slower with large model matrices (e.g. when hinge is enabled). Defaults to `LassoBackend()`.
-- `kw...`: Further arguments to be passed to Lasso.fit or GLMNet.glmnet
+- `kw...`: Further arguments to be passed to `Lasso.fit` or `GLMNet.glmnet`
 
 # Returns
 - `model`: A model of type `MaxnetModel`
 
+# Examples
+```jldoctest
+using Maxnet
+p_a, env = Maxnet.bradypus()
+
+bradypus_model = maxnet(p_a, env; features = "lq", backend = GLMNetBackend());
+```
+
 """
 function maxnet(
     presences::BitVector, predictors;
diff --git a/src/mlj_interface.jl b/src/mlj_interface.jl
index e04d302..0ead673 100644
--- a/src/mlj_interface.jl
+++ b/src/mlj_interface.jl
@@ -1,12 +1,3 @@
-#=MMI.@mlj_model mutable struct MaxnetBinaryClassifier <: MMI.Deterministic
-    features = ""
-    regularization_multiplier = 1.0
-    regularization_function = default_regularization
-    weight_factor = 100.
-    backend = LassoBackend
-    kw...
-end
-=#
 mutable struct MaxnetBinaryClassifier <: MMI.Probabilistic
     features::Union{String, Vector{<:AbstractFeatureClass}}
     regularization_multiplier::Float64
@@ -14,6 +5,7 @@ mutable struct MaxnetBinaryClassifier <: MMI.Probabilistic
     weight_factor::Float64
     backend::MaxnetBackend
     link::GLM.Link
+    clamp::Bool
     kw
 end
@@ -21,15 +13,39 @@ function MaxnetBinaryClassifier(;
     features="", regularization_multiplier = 1.0, regularization_function = default_regularization,
     weight_factor = 100., backend = LassoBackend(),
-    link = CloglogLink(),
+    link = CloglogLink(), clamp = false,
     kw...
 )
+
     MaxnetBinaryClassifier(
         features, regularization_multiplier, regularization_function,
-        weight_factor, backend, link, kw
+        weight_factor, backend, link, clamp, kw
     )
 end
 
+"""
+    MaxnetBinaryClassifier
+
+A model type for fitting a maxnet model using `MLJ`.
+
+Use `MaxnetBinaryClassifier()` to create an instance with default parameters, or use keyword arguments to specify parameters.
+
+The `link` and `clamp` keywords are passed to `predict`. All other keywords are passed to `maxnet` when the model is fit.
+See the documentation of [`maxnet`](@ref) for the parameters and their defaults.
+
+# Examples
+```jldoctest
+using Maxnet, MLJBase
+p_a, env = Maxnet.bradypus()
+
+mach = machine(MaxnetBinaryClassifier(features = "lqp"), env, categorical(p_a))
+fit!(mach)
+yhat = MLJBase.predict(mach, env)
+```
+
+"""
+MaxnetBinaryClassifier
+
 MMI.input_scitype(::Type{<:MaxnetBinaryClassifier}) = MMI.Table{<:Union{<:AbstractVector{<:Continuous}, <:AbstractVector{<:Multiclass}}} #{<:Union{<:Continuous <:Multiclass}}
@@ -65,6 +81,6 @@ function MMI.fit(m::MaxnetBinaryClassifier, verbosity::Int, X, y)
 end
 
 function MMI.predict(m::MaxnetBinaryClassifier, (fitresult, decode), Xnew)
-    p = predict(fitresult, Xnew; link = m.link)
+    p = predict(fitresult, Xnew; link = m.link, clamp = m.clamp)
     MMI.UnivariateFinite(decode, [1 .- p, p])
 end
\ No newline at end of file
diff --git a/src/predict.jl b/src/predict.jl
index ca01e39..f92a7ae 100644
--- a/src/predict.jl
+++ b/src/predict.jl
@@ -1,11 +1,9 @@
-""""
+"""
     predict(m, x; link, clamp)
 
-Use a fit maxnet model to predict
-
 # Arguments
-- `m`: a MaxnetModel generated by maxnet()
-- `x`: a Tables.jl-compatible table of predictors. All columns that were used to fit `m` should be present in `x`
+- `m`: a MaxnetModel as returned by `maxnet`
+- `x`: a `Tables.jl`-compatible table of predictors. All columns that were used to fit `m` should be present in `x`
 
 # Keywords
 - `link`: the link function used. Defaults to CloglogLink(), which is the default on the Maxent Java application since version 4.3.
@@ -15,10 +13,9 @@
 - `clamp`: If `true`, values in `x` will be clamped to the range the model was trained on. Defaults to `false`.
 
 # Returns
-A `Vector` of prediction values.
+A `Vector` with the resulting predictions.
 """
-
 function predict(m::MaxnetModel, x; link = CloglogLink(), clamp = false)
     predictors = Tables.columntable(x)
     for k in keys(predictors)
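For reference, a minimal end-to-end sketch of the API touched by this diff: fitting with `maxnet` and then calling `predict` with the `link` and `clamp` keywords documented above. It assumes the `Maxnet.bradypus()` demo data used in the docstring examples; the feature string and `clamp = true` are illustrative choices, not package defaults.

```julia
using Maxnet

# Demo presence/absence vector and predictor table used in the docstring examples
p_a, env = Maxnet.bradypus()

# Fit with linear, categorical and quadratic features
model = maxnet(p_a, env; features = "lq")

# Predict on the cloglog scale (the default link), clamping predictors
# to the range seen while fitting
predictions = predict(model, env; link = CloglogLink(), clamp = true)
```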