Skip to content

Commit

Permalink
fully implement MLJ (#22)
Browse files Browse the repository at this point in the history
* add mlj docstring

* test with MLJTestInterface

* throw a helpful error if input data only has one class

* mljtestinterface is not a dep (oops)

* move allequal error to main function

* fix allequal error

* fix tests

* add MLJBase as docs dep

* fix mlj doctest

* attempt fix of multiclass printing

* use @example instead of jldoctest

* test for no failures in mlj interface test
  • Loading branch information
tiemvanderdeure authored Dec 3, 2024
1 parent 73daa4d commit 1742c96
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 27 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ julia = "1.9"
[extras]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DelimitedFiles", "MLJBase", "Test"]
test = ["DelimitedFiles", "MLJBase", "MLJTestInterface", "Test"]
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Maxnet = "81f79f80-22f2-4e41-ab86-00c11cf0f26f"
5 changes: 5 additions & 0 deletions src/maxnet_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ function maxnet(
n_knots::Int = 50,
kw...)

if allequal(presences)
pa = first(presences) ? "presences" : "absences"
throw(ArgumentError("All data points are $pa. Maxnet will only work with at least some presences and some absences."))
end

_maxnet(
presences,
predictors,
Expand Down
44 changes: 20 additions & 24 deletions src/mlj_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,6 @@ function MaxnetBinaryClassifier(;
)
end

"""
MaxnetBinaryClassifier
A model type for fitting a maxnet model using `MLJ`.
Use `MaxnetBinaryClassifier()` to create an instance with default parameters, or use keyword arguments to specify parameters.
The keywords `link`, and `clamp` are passed to [`Maxnet.predict`](@ref), while all other keywords are passed to [`maxnet`](@ref).
See the documentation of these functions for the meaning of these parameters and their defaults.
# Example
```jldoctest
using Maxnet, MLJBase
p_a, env = Maxnet.bradypus()
mach = machine(MaxnetBinaryClassifier(features = "lqp"), env, categorical(p_a))
fit!(mach)
yhat = MLJBase.predict(mach, env)
# output
```
"""
MaxnetBinaryClassifier

MMI.metadata_pkg(
MaxnetBinaryClassifier;
name = "Maxnet",
Expand All @@ -67,6 +43,26 @@ MMI.metadata_model(
reports_feature_importances=false
)

"""
$(MMI.doc_header(MaxnetBinaryClassifier))
The keywords `link`, and `clamp` are passed to [`predict`](@ref), while all other keywords are passed to [`maxnet`](@ref).
See the documentation of these functions for the meaning of these parameters and their defaults.
# Example
```@example
using MLJBase
p_a, env = Maxnet.bradypus()
mach = machine(MaxnetBinaryClassifier(features = "lqp"), env, categorical(p_a), scitype_check_level = 0)
fit!(mach, verbosity = 0)
yhat = MLJBase.predict(mach, env)
```
"""
MaxnetBinaryClassifier

function MMI.fit(m::MaxnetBinaryClassifier, verbosity::Int, X, y)
# convert categorical to boolean
y_boolean = Bool.(MMI.int(y) .- 1)
Expand Down
15 changes: 13 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using Maxnet, Test, Statistics, CategoricalArrays
using Maxnet, Statistics, CategoricalArrays, MLJTestInterface
using Test

# read in Bradypus data
p_a, env = Maxnet.bradypus()
# Make the levels in ecoreg string to make sure that that works
env = merge(env, (; ecoreg = recode(env.ecoreg, (l => string(l) for l in levels(env.ecoreg))...)))
Expand Down Expand Up @@ -82,9 +84,18 @@ end
m = maxnet(p_a, env; features = "lq", addsamplestobackground = false)
@test m_w.entropy > m.entropy
end
m = maxnet(p_a, env; features = "lq", addsamplestobackground = false)

@testset "MLJ" begin
data = MLJTestInterface.make_binary()
failures, summary = MLJTestInterface.test(
[MaxnetBinaryClassifier],
data...;
mod=@__MODULE__,
verbosity=0, # bump to debug
throw=false, # set to true to debug
)
@test isempty(failures)

using MLJBase
mn = Maxnet.MaxnetBinaryClassifier

Expand Down

0 comments on commit 1742c96

Please sign in to comment.