Skip to content

Commit

Permalink
improve documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
tiemvanderdeure committed Dec 21, 2023
1 parent fed6d43 commit d899a1c
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 25 deletions.
1 change: 1 addition & 0 deletions src/Maxnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export IdentityLink, CloglogLink, LogitLink, LogLink # re-export relevant links
export LassoBackend, GLMNetBackend
export maxnet, predict
export LinearFeature, CategoricalFeature, QuadraticFeature, ProductFeature, ThresholdFeature, HingeFeature
export MaxnetBinaryClassifier

# Write your package code here.

Expand Down
11 changes: 11 additions & 0 deletions src/feature_classes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ function features_from_string(s::AbstractString)
end

# Default features based on number of presences

"""
default_features(np)
Takes the number of presences `np` and returns a `Vector` of `AbstractFeatureClass`s that are used my maxent as default.
If `np` is less than ten, then only `LinearFeature` and `CategoricalFeature` are used.
If it is at least 10, then `QuadraticFeature` is additionally used.
If it is at least 15, then `HingeFeature` is additionally used.
If it is at least 80, then `ProductFeature` is additionally used.
"""
function default_features(np)
features = [LinearFeature(), CategoricalFeature()]
if np >= 10
Expand Down
21 changes: 15 additions & 6 deletions src/maxnet_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ end

"""
maxnet(
presences, predictors, [features];
presences, predictors;
[features],
regularization_multiplier, regularization_function,
addsamplestobackground, weight_factor, backend,
kw...
Expand All @@ -36,20 +37,28 @@ end
# Keywords
- `features`: Either:
- A `Vector` of `AbstractFeatureClass` type features; or
- A string where "l" = linear and categorical, "q" = quadratic, "p" = product, "t" = threshold, "h" = hinge; or
- Nothing, in which case the default features based on the number of presences are used
- A `String` where "l" = linear and categorical, "q" = quadratic, "p" = product, "t" = threshold, "h" = hinge (e.g. "lqh"); or
- The default, in which case the features are based on the number of presences are used. See [`default_features`](@ref)
- `regularization_multiplier`: A constant to adjust regularization, where a higher `regularization_multiplier` results in a higher penalization for features
- `regularization_function`: A function to compute a regularization for each feature. A default `regularization_function` is built in.
- `addsamplestobackground`: A boolean, where `true` adds the background samples to the predictors. Defaults to `true`.
- `n_knots`: the number of knots used for Threshold and Hinge features. Defaults to 50. Ignored if there are neither Threshold nor Hinge features
- `weight_factor`: A `Float64` to adjust the weight of the background samples. Defaults to 100.0.
- `backend`: Either `LassoBackend()` or `GLMNetBackend()`, to use Lasso.jl or GLMNet.jl fit the model.
- `weight_factor`: A `Float64` value to adjust the weight of the background samples. Defaults to 100.0.
- `backend`: Either `LassoBackend()` or `GLMNetBackend()`, to use either Lasso.jl or GLMNet.jl to fit the model.
Lasso.jl is written in pure julia, but can be slower with large model matrices (e.g. when hinge is enabled). Defaults to `LassoBackend`.
- `kw...`: Further arguments to be passed to Lasso.fit or GLMNet.glmnet
- `kw...`: Further arguments to be passed to `Lasso.fit` or `GLMNet.glmnet`
# Returns
- `model`: A model of type `MaxnetModel`
# Examples
```jldoctest
using Maxnet
p_a, env = Maxnet.bradypus()
bradypus_model = maxnet(p_a, env; features = "lq", backend = GLMNetBackend());
```
"""
function maxnet(
presences::BitVector, predictors;
Expand Down
40 changes: 28 additions & 12 deletions src/mlj_interface.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,51 @@
#=MMI.@mlj_model mutable struct MaxnetBinaryClassifier <: MMI.Deterministic
features = ""
regularization_multiplier = 1.0
regularization_function = default_regularization
weight_factor = 100.
backend = LassoBackend
kw...
end
=#
mutable struct MaxnetBinaryClassifier <: MMI.Probabilistic
features::Union{String, Vector{<:AbstractFeatureClass}}
regularization_multiplier::Float64
regularization_function
weight_factor::Float64
backend::MaxnetBackend
link::GLM.Link
clamp::Bool
kw
end

function MaxnetBinaryClassifier(;
features="",
regularization_multiplier = 1.0, regularization_function = default_regularization,
weight_factor = 100., backend = LassoBackend(),
link = CloglogLink(),
link = CloglogLink(), clamp = false,
kw...
)

MaxnetBinaryClassifier(
features, regularization_multiplier, regularization_function,
weight_factor, backend, link, kw
weight_factor, backend, link, clamp, kw
)
end

"""
MaxnetBinaryClassifier
A model type for fitting a maxnet model using `MLJ`.
Use `MaxnetBinaryClassifier()` to create an instance with default parameters, or use keyword arguments to specify parameters.
The `link` and `clamp` keywords are passed to `predict`. All other keywords are passed to `maxnet` when te model is fit.
See the documentation of [`maxnet`](@ref) for the parameters and their defaults.
# Examples
```jldoctest
using Maxnet, MLJBase
p_a, env = Maxnet.bradypus()
mach = machine(MaxnetBinaryClassifier(features = "lqp"), env, categorical(p_a))
fit!(mach)
yhat = MLJBase.predict(mach, env)
```
"""
MaxnetBinaryClassifier

MMI.input_scitype(::Type{<:MaxnetBinaryClassifier}) =
MMI.Table{<:Union{<:AbstractVector{<:Continuous}, <:AbstractVector{<:Multiclass}}} #{<:Union{<:Continuous <:Multiclass}}

Expand Down Expand Up @@ -65,6 +81,6 @@ function MMI.fit(m::MaxnetBinaryClassifier, verbosity::Int, X, y)
end

function MMI.predict(m::MaxnetBinaryClassifier, (fitresult, decode), Xnew)
p = predict(fitresult, Xnew; link = m.link)
p = predict(fitresult, Xnew; link = m.link, clamp = m.clamp)
MMI.UnivariateFinite(decode, [1 .- p, p])
end
11 changes: 4 additions & 7 deletions src/predict.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
""""
"""
predict(m, x; link, clamp)
Use a fit maxnet model to predict
# Arguments
- `m`: a MaxnetModel generated by maxnet()
- `x`: a Tables.jl-compatible table of predictors. All columns that were used to fit `m` should be present in `x`
- `m`: a MaxnetModel as returned by `maxnet`
- `x`: a `Tables.jl`-compatible table of predictors. All columns that were used to fit `m` should be present in `x`
# Keywords
- `link`: the link function used. Defaults to CloglogLink(), which is the default on the Maxent Java appliation since version 4.3.
Expand All @@ -15,10 +13,9 @@
- `clamp`: If `true`, values in `x` will be clamped to the range the model was trained on. Defaults to `false`.
# Returns
A `Vector` of prediction values.
A `Vector` with the resulting predictions.
"""

function predict(m::MaxnetModel, x; link = CloglogLink(), clamp = false)
predictors = Tables.columntable(x)
for k in keys(predictors)
Expand Down

0 comments on commit d899a1c

Please sign in to comment.