diff --git a/Project.toml b/Project.toml index 034d148..aa30d6f 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,8 @@ julia = "1.9" [extras] DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["DelimitedFiles", "MLJBase", "Test"] +test = ["DelimitedFiles", "MLJBase", "MLJTestInterface", "Test"] diff --git a/docs/Project.toml b/docs/Project.toml index 491d2dd..3f5ef3c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,4 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" Maxnet = "81f79f80-22f2-4e41-ab86-00c11cf0f26f" diff --git a/src/maxnet_function.jl b/src/maxnet_function.jl index 9e87636..fd414ac 100644 --- a/src/maxnet_function.jl +++ b/src/maxnet_function.jl @@ -49,6 +49,11 @@ function maxnet( n_knots::Int = 50, kw...) + if allequal(presences) + pa = first(presences) ? "presences" : "absences" + throw(ArgumentError("All data points are $pa. Maxnet will only work with at least some presences and some absences.")) + end + _maxnet( presences, predictors, diff --git a/src/mlj_interface.jl b/src/mlj_interface.jl index 878a21d..e081801 100644 --- a/src/mlj_interface.jl +++ b/src/mlj_interface.jl @@ -24,30 +24,6 @@ function MaxnetBinaryClassifier(; ) end -""" - MaxnetBinaryClassifier - - A model type for fitting a maxnet model using `MLJ`. - - Use `MaxnetBinaryClassifier()` to create an instance with default parameters, or use keyword arguments to specify parameters. - - The keywords `link`, and `clamp` are passed to [`Maxnet.predict`](@ref), while all other keywords are passed to [`maxnet`](@ref). - See the documentation of these functions for the meaning of these parameters and their defaults. - - # Example - ```jldoctest - using Maxnet, MLJBase - p_a, env = Maxnet.bradypus() - - mach = machine(MaxnetBinaryClassifier(features = "lqp"), env, categorical(p_a)) - fit!(mach) - yhat = MLJBase.predict(mach, env) - # output - ``` - -""" -MaxnetBinaryClassifier - MMI.metadata_pkg( MaxnetBinaryClassifier; name = "Maxnet", @@ -67,6 +43,26 @@ MMI.metadata_model( reports_feature_importances=false ) +""" +$(MMI.doc_header(MaxnetBinaryClassifier)) + +The keywords `link`, and `clamp` are passed to [`predict`](@ref), while all other keywords are passed to [`maxnet`](@ref). +See the documentation of these functions for the meaning of these parameters and their defaults. + +# Example +```@example +using MLJBase +p_a, env = Maxnet.bradypus() + +mach = machine(MaxnetBinaryClassifier(features = "lqp"), env, categorical(p_a), scitype_check_level = 0) +fit!(mach, verbosity = 0) +yhat = MLJBase.predict(mach, env) + +``` + +""" +MaxnetBinaryClassifier + function MMI.fit(m::MaxnetBinaryClassifier, verbosity::Int, X, y) # convert categorical to boolean y_boolean = Bool.(MMI.int(y) .- 1) diff --git a/test/runtests.jl b/test/runtests.jl index ed8e29a..e776b16 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,7 @@ -using Maxnet, Test, Statistics, CategoricalArrays +using Maxnet, Statistics, CategoricalArrays, MLJTestInterface +using Test +# read in Bradypus data p_a, env = Maxnet.bradypus() # Make the levels in ecoreg string to make sure that that works env = merge(env, (; ecoreg = recode(env.ecoreg, (l => string(l) for l in levels(env.ecoreg))...))) @@ -82,9 +84,18 @@ end m = maxnet(p_a, env; features = "lq", addsamplestobackground = false) @test m_w.entropy > m.entropy end -m = maxnet(p_a, env; features = "lq", addsamplestobackground = false) @testset "MLJ" begin + data = MLJTestInterface.make_binary() + failures, summary = MLJTestInterface.test( + [MaxnetBinaryClassifier], + data...; + mod=@__MODULE__, + verbosity=0, # bump to debug + throw=false, # set to true to debug + ) + @test isempty(failures) + using MLJBase mn = Maxnet.MaxnetBinaryClassifier