
Question on how to correctly implement an interface to a probabilistic classifier #211

Closed
pasq-cat opened this issue Sep 19, 2024 · 2 comments

Comments

@pasq-cat

Hi, I was trying to implement an interface between LaplaceRedux and MLJ, but I am facing an issue with the probabilistic classifier model. In particular, I have not fully understood how to use UnivariateFinite correctly.

I have imported the packages

using Flux
using Random
using Tables
using LinearAlgebra
using LaplaceRedux
using MLJBase
using MLJFlux  # provides the MLJFluxProbabilistic supertype used below
import MLJModelInterface as MMI
using Distributions: Normal

created the model

MLJBase.@mlj_model mutable struct LaplaceClassifier <: MLJFlux.MLJFluxProbabilistic

    flux_model::Flux.Chain = nothing
    flux_loss = Flux.Losses.logitcrossentropy
    epochs::Integer = 1000::(_ > 0)
    batch_size::Integer= 32::(_ > 0)
    subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork))
    subnetwork_indices = nothing
    hessian_structure::Union{HessianStructure,Symbol,String} =
        :full::(_ in (:full, :diagonal))
    backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher))
    σ::Float64 = 1.0
    μ₀::Float64 = 0.0
    P₀::Union{AbstractMatrix,UniformScaling,Nothing} = nothing
    #ret_distr::Bool = false::(_ in (true, false))
    fit_prior_nsteps::Int = 100::(_ > 0)
    link_approx::Symbol = :probit::(_ in (:probit, :plugin))
end

written a fit function

function MMI.fit(m::LaplaceClassifier, verbosity, X, y)
    X = MLJBase.matrix(X) |> permutedims
    decode = y[1]
    y_plain = MLJBase.int(y) .- 1
    y_onehot = Flux.onehotbatch(y_plain, unique(y_plain))
    data_loader = Flux.DataLoader((X, y_onehot); batchsize=m.batch_size)
    opt_state = Flux.setup(Adam(), m.flux_model)

    for epoch in 1:m.epochs
        Flux.train!(m.flux_model, data_loader, opt_state) do model, X, y_onehot
            m.flux_loss(model(X), y_onehot)
        end
    end

    la = LaplaceRedux.Laplace(
        m.flux_model;
        likelihood=:classification,
        subset_of_weights=m.subset_of_weights,
        subnetwork_indices=m.subnetwork_indices,
        hessian_structure=m.hessian_structure,
        backend=m.backend,
        σ=m.σ,
        μ₀=m.μ₀,
        P₀=m.P₀,
    )

    # fit the Laplace approximation:
    LaplaceRedux.fit!(la, data_loader)
    optimize_prior!(la; verbose=false, n_steps=m.fit_prior_nsteps)

    report = (status="success", message="Model fitted successfully")
    cache = nothing
    return ((la, decode), cache, report)
end

and the predict function

function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
    la = fitresult
    Xnew = MLJBase.matrix(Xnew) |> permutedims
    predictions = LaplaceRedux.predict(
        la,
        Xnew;
        link_approx=m.link_approx,
        ret_distr=false)
    return [
        MLJBase.UnivariateFinite(
            MLJBase.classes(decode), prediction; pool=decode, augment=true
        ) for prediction in predictions
    ]
end

but when I run predict I get the warning

Warning: Ignoring value of pool as the specified support defines one already.

which points at the last line, where UnivariateFinite is constructed.

@ablaom
Member

ablaom commented Sep 19, 2024

I think you can just drop pool=decode as MLJBase.classes(decode) is a categorical vector, which therefore already includes the pool. You only need to specify a pool if the first argument of UnivariateFinite is a raw vector (not a categorical vector).

(Elements of the first argument not in the pool still get assigned a probability, namely zero.)

Let me know if that does not work.
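The distinction above can be sketched in a few lines (a minimal example, assuming MLJBase and CategoricalArrays are installed; the "yes"/"no" labels and probabilities are made up for illustration):

```julia
using MLJBase
using CategoricalArrays

y = categorical(["yes", "no", "yes"])  # categorical vector; carries its own pool

# classes(y) is itself a categorical vector, so no `pool` keyword is needed:
d1 = UnivariateFinite(classes(y), [0.2, 0.8])

# with a raw vector of labels, the pool must be supplied explicitly:
d2 = UnivariateFinite(["no", "yes"], [0.2, 0.8], pool=y)

pdf(d1, "yes")  # probability assigned to the class "yes"
```

Passing both a categorical support and the `pool` keyword, as in the original predict, is what triggers the "Ignoring value of pool" warning.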

@pasq-cat
Author

pasq-cat commented Sep 19, 2024

Yes, it works. There were also a couple of other errors in predict that I found once the first one was solved.
Thank you, ablaom.

I changed it to

function MMI.predict(m::LaplaceClassifier, (fitresult, decode), Xnew)
    la = fitresult
    Xnew = MLJBase.matrix(Xnew) |> permutedims
    predictions = LaplaceRedux.predict(
        la,
        Xnew;
        link_approx=m.link_approx,
        ret_distr=false,
    ) |> permutedims

    return MLJBase.UnivariateFinite(MLJBase.classes(decode), predictions)
end

since the documentation says the UnivariateFinite distributions shouldn't be constructed one at a time.
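The batch form can be sketched like this (a minimal example, again assuming MLJBase and CategoricalArrays; labels and probabilities are made up): passing an observation-by-class probability matrix to UnivariateFinite returns the whole vector of distributions in one call, which is why the comprehension is no longer needed.

```julia
using MLJBase
using CategoricalArrays

y = categorical(["no", "yes", "no"])

# rows are observations; columns follow the order of classes(y):
probs = [0.1 0.9;
         0.7 0.3]

# one call builds all the distributions at once:
dists = UnivariateFinite(classes(y), probs)

pdf(dists[1], "yes")  # probability of "yes" for the first observation
```

The result also supports broadcasting, e.g. `pdf.(dists, "yes")`, which is convenient for evaluation code downstream.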

@ablaom ablaom closed this as completed Sep 19, 2024