diff --git a/src/Hyperopt.jl b/src/Hyperopt.jl index 8564cd4..04c32aa 100644 --- a/src/Hyperopt.jl +++ b/src/Hyperopt.jl @@ -6,6 +6,7 @@ export RandomSampler, LHSampler, CLHSampler, hyperband, Hyperband, hyperoptim, B using Base.Threads: threadid, nthreads using LinearAlgebra, Statistics, Random using ProgressMeter +using ComponentArrays using MacroTools using MacroTools: postwalk, prewalk using RecipesBase @@ -37,11 +38,23 @@ Base.@kwdef mutable struct Hyperoptimizer{S<:Sampler, F} objective::F = nothing end +function namedtuple_params(ho, i) + vals = ho.history[i] + ComponentVector((; zip(ho.params, vals)...)) +end + function Base.getproperty(ho::Hyperoptimizer, s::Symbol) s === :minimum && (return isempty(ho.results) ? NaN : minimum(replace(ho.results, NaN => Inf))) - s === :minimizer && (return isempty(ho.results) ? [] : ho.history[argmin(replace(ho.results, NaN => Inf))]) + # s === :minimizer && (return isempty(ho.results) ? [] : ho.history[argmin(replace(ho.results, NaN => Inf))]) s === :maximum && (return isempty(ho.results) ? NaN : maximum(replace(ho.results, NaN => -Inf))) - s === :maximizer && (return isempty(ho.results) ? [] : ho.history[argmax(replace(ho.results, NaN => -Inf))]) + # s === :maximizer && (return isempty(ho.results) ? [] : ho.history[argmax(replace(ho.results, NaN => -Inf))]) + if s === :minimizer + return isempty(ho.results) ? [] : namedtuple_params(ho, argmin(replace(ho.results, NaN => Inf))) + end + s === :maximum && (return isempty(ho.results) ? NaN : maximum(replace(ho.results, NaN => Inf))) + if s === :maximizer + return isempty(ho.results) ? [] : namedtuple_params(ho, argmax(replace(ho.results, NaN => Inf))) + end return getfield(ho,s) end @@ -356,7 +369,7 @@ function warn_on_boundary(ho, sense = :min) (m[i],) end end - for i in eachindex(m) + for i in 1:length(m) c = unique(ho.candidates[i]) if m[i] ∈ extremas[i] && length(c) > 3 println("Parameter $(ho.params[i]) obtained its optimum on an extremum of the sampled region: $(m[i])") diff --git a/src/optim.jl b/src/optim.jl index fa1aa20..bd296c3 100644 --- a/src/optim.jl +++ b/src/optim.jl @@ -30,4 +30,50 @@ function hyperoptim(f, candidates, algorithm = Optim.NelderMead(), opts = Optim. ho = hyperband(fun, candidates; R, η, inner, threads) ho +end + +function multistart(f, candidates; N, algorithm = Optim.NelderMead(), opts = Optim.Options(), inner = RandomSampler(), threads=false) + ho = Hyperoptimizer(; + iterations = N, + params = [Symbol("$i") for i in eachindex(candidates)], + candidates, + history = Vector{Any}(undef, N), + results = Vector{Any}(undef, N), + sampler = inner, + objective = f, + ) + if inner isa Union{LHSampler,CLHSampler} + ho.iterations = length(candidates[1]) + init!(inner, ho) + end + sem = Base.Semaphore(threads) + try + @sync for i = 1:N + pars = ho.sampler(ho, i) + # nt = (; Pair.((:i, ho.params...), (state, samples...))...) + # pars = [Base.tail(nt)...] # the first element is i + + if threads >= 2 + Base.acquire(sem) + Threads.@spawn begin + res = Optim.optimize(f, pars, algorithm, opts) + ho.history[i] = res.minimizer + ho.results[i] = res.minimum + Base.release(sem) + end + + else + res = Optim.optimize(f, pars, algorithm, opts) + ho.history[i] = res.minimizer + ho.results[i] = res.minimum + end + end + catch e + if e isa InterruptException + @info "Aborting hyperoptimization" + else + rethrow() + end + end + ho end \ No newline at end of file