Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework high level interfaces #392

Merged
merged 107 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
107 commits
Select commit Hold shift + click to select a range
2484fae
very first sketch of a new function struct type to wrap them and make…
kellertuer Jun 5, 2024
ab09993
Revert "very first sketch of a new function struct type to wrap them …
kellertuer Jun 5, 2024
9aae785
trigger all ambiguity errors.
kellertuer Jun 5, 2024
232e7b3
Move dispatch on p into subfunctions, that existed before anyways.
kellertuer Jun 6, 2024
91d1038
rework the safeguards to dispatch internally.
kellertuer Jun 6, 2024
da222a0
rework the safeguards to dispatch internally.
kellertuer Jun 6, 2024
b1777d2
Merge branch 'kellertuer/rework-high-level-interfaces' of github.com:…
kellertuer Jun 6, 2024
c85c98b
Fix a few bugs.
kellertuer Jun 6, 2024
0691ae3
Merge branch 'master' into kellertuer/rework-high-level-interfaces
kellertuer Jun 6, 2024
6b4291f
fixes all ambiguities but the ones with TCG.
kellertuer Jun 6, 2024
c8944e1
runs formatter.
kellertuer Jun 6, 2024
9b525ae
bump version to 0.5.0, describe what is beaking, simplify code and te…
kellertuer Jun 6, 2024
912d47e
remove deprecated definitions.
kellertuer Jun 6, 2024
a56f6a9
runs formatter.
kellertuer Jun 6, 2024
26fe778
Fix documenter. Increase bound on ManoptExamples to use the extension.
kellertuer Jun 8, 2024
25aca12
Loosen constraints on ManoptExamples since it works with 0.1.6 until …
kellertuer Jun 9, 2024
dcc2d74
Merge branch 'master' into kellertuer/rework-high-level-interfaces
kellertuer Jun 9, 2024
56e3504
Update test/test_aqua.jl
kellertuer Jun 9, 2024
a7b9eb2
raise dependency.
kellertuer Jun 9, 2024
13d99ec
Fix a typo in reducing Aqua.
kellertuer Jun 9, 2024
7b382d9
remove all deprecated parameters.
kellertuer Jun 12, 2024
7d8cc46
runs formatter on the correct folder.
kellertuer Jun 12, 2024
0a91e10
Fix a stopping criterion.
kellertuer Jun 13, 2024
9482223
Consistently wrap numbers in arrays for points/tangent vectors intern…
kellertuer Jun 13, 2024
0419908
Merge branch 'master' into kellertuer/rework-high-level-interfaces
kellertuer Jun 13, 2024
58b7fad
Merge branch 'master' into kellertuer/rework-high-level-interfaces
kellertuer Jul 5, 2024
d6fa02b
Fix typos to correctly handle constraints
hajg-ijk Aug 5, 2024
e9e827d
Merge branch 'master' into kellertuer/rework-high-level-interfaces
kellertuer Aug 5, 2024
b2455c2
Fix two places, where I was a bit too fast.
kellertuer Aug 5, 2024
bf7451c
Maybe like this?
kellertuer Aug 5, 2024
509e6a3
Merge branch 'master' into kellertuer/rework-high-level-interfaces
kellertuer Aug 11, 2024
7ee5ee9
Reiterate on deprecated things (reintroduced by merging).
kellertuer Aug 11, 2024
f1c8fd6
rename set_manopt_parameter! to set_parameter and get_manopt_paramete…
kellertuer Aug 12, 2024
fb5a5c4
unify stabilisation through projection keyword.
kellertuer Aug 12, 2024
92ad1f9
remove `update_stopping_criterion` to update values in the stopping c…
kellertuer Aug 12, 2024
4ea449f
Fix docs.
kellertuer Aug 12, 2024
6d5c422
Fix quarto env.
kellertuer Aug 12, 2024
74fba25
Maybe tweak tutorials to still/again run?
kellertuer Aug 12, 2024
6e3dcb5
unify state constructors of ALM and ARC
kellertuer Aug 13, 2024
88445e5
Adapt EPM.
kellertuer Aug 13, 2024
6e95c7c
Finish DoC and DCPPA
kellertuer Aug 13, 2024
c3d20ec
Unify how p and X are passed to states overall in all states.
kellertuer Aug 13, 2024
edd3481
Fix all tests.
kellertuer Aug 13, 2024
f7363c6
Fix a typo in the quarto notebook.
kellertuer Aug 13, 2024
3147fb4
Sketch a first idea of a factory.
kellertuer Aug 14, 2024
3ea9052
Fix a few typos in the docs.
kellertuer Aug 14, 2024
033f9d7
Check and fix docs.
kellertuer Aug 14, 2024
86c6742
Work on test coverage for new state constructors.
kellertuer Aug 14, 2024
08a0511
Remove some deprecated old code and increase code cov
kellertuer Aug 14, 2024
649bc7e
increase test coverage.
kellertuer Aug 14, 2024
bb574c2
Fix scaling parameter in quasi newton. This resolvs #382
kellertuer Aug 16, 2024
b53f783
Start changing the docs snippets to using a dictionary approach inste…
kellertuer Aug 17, 2024
1e5f922
Twiddling a bit with the new factory.
kellertuer Aug 17, 2024
70045da
adds a note.
kellertuer Aug 17, 2024
70ab110
Apply suggestions from code review
kellertuer Aug 19, 2024
08a868c
Generalise the factory to more than just Direction Updates; address p…
kellertuer Aug 19, 2024
0670728
Merge branch 'kellertuer/rework-high-level-interfaces' of github.com:…
kellertuer Aug 19, 2024
1bf5570
Start meta fields in the glossary. Format can probably be improved
kellertuer Aug 19, 2024
574e158
Work a bit towards the glossary idea.
kellertuer Aug 19, 2024
085c976
order alphabetically.
kellertuer Aug 19, 2024
5430abb
Start introducing a proper glossary,
kellertuer Aug 20, 2024
585f71a
First base for a glossary.
kellertuer Aug 20, 2024
dec40cd
Trying the new glossary in a slightly larger scale, it seems to work …
kellertuer Aug 20, 2024
822856c
runs formatter.
kellertuer Aug 20, 2024
80c2aa0
A bit of further glossary work.
kellertuer Aug 20, 2024
b4e1a6a
Before moving to the factory lets first get the glossary established …
kellertuer Aug 21, 2024
2cda5e0
Replace further arguments.
kellertuer Aug 21, 2024
4f1fe6d
Finishing all argument glossary entries.
kellertuer Aug 21, 2024
bfe4d37
Simplify first Keyword which is now really much nicer to write.
kellertuer Aug 21, 2024
cb8c9a4
Codecov updates.
kellertuer Aug 21, 2024
bdcf267
acidentially removed a string short.
kellertuer Aug 21, 2024
bc0a120
Unify most retractions, inverse retraction and vector transport keywo…
kellertuer Aug 22, 2024
3ce1ccf
Refactor docs further.
kellertuer Aug 22, 2024
1884103
refactor docs for sub problem and state keywords
kellertuer Aug 22, 2024
212fa4e
📚Finish the glossary work
kellertuer Aug 22, 2024
e79295b
Refactor types and docs for ConjugateDescentCoefficient and DaiYuanCo…
kellertuer Aug 22, 2024
a5f9a3c
Move 3 more gradient rules over to the factory pattern.
kellertuer Aug 22, 2024
7c8a382
Finish. the remaining rules.
kellertuer Aug 23, 2024
79382e7
Add test coverage.
kellertuer Aug 23, 2024
5563c2d
Adds a final test for today.
kellertuer Aug 23, 2024
5598bab
Fix two typos.
kellertuer Aug 23, 2024
02c4689
And another final final test.
kellertuer Aug 23, 2024
d42990f
Refactor Armijo line search to use a factory.
kellertuer Aug 24, 2024
5b3956b
Finish Polyak docs.
kellertuer Aug 24, 2024
0e488b9
Fix docs.
kellertuer Aug 24, 2024
c387209
Fix/simplify tutorial.
kellertuer Aug 24, 2024
f81f048
Fix a typo, thanks to tutorial mode, there were important warning.
kellertuer Aug 24, 2024
04f56a5
Fix two typos in the docs.
kellertuer Aug 24, 2024
b23e549
Code coverage.
kellertuer Aug 24, 2024
cec7454
Trying to bump to Manifolds 0.10.
kellertuer Aug 25, 2024
8655671
Maybe it works now?
kellertuer Aug 25, 2024
9881f73
Move product manifold on alternating gradient to a new extension.
kellertuer Aug 25, 2024
5db1943
We are eneting dependency fun Part 3.
kellertuer Aug 25, 2024
def9c97
First time ever I am running int dependency hell with Julia. Somethin…
kellertuer Aug 25, 2024
508ad58
Maybe now?
kellertuer Aug 25, 2024
5f6f929
First step in wolfe powell.
kellertuer Aug 25, 2024
0dfb2e6
Minor fixes.
kellertuer Aug 25, 2024
6836dc8
fnumps.
kellertuer Aug 25, 2024
fce828f
Fnord.
kellertuer Aug 25, 2024
6e5cef9
Finish NM rework.
kellertuer Aug 26, 2024
108566d
Finish AWN.
kellertuer Aug 26, 2024
868c02d
Fix docs to render cases.
kellertuer Aug 26, 2024
9f48ba7
A few final tweaks to the docs.
kellertuer Aug 26, 2024
318bca3
Unify signatures, fix a few tex typos and improve english on the fact…
kellertuer Aug 27, 2024
0bf2244
runs formatter.
kellertuer Aug 27, 2024
2242159
Small fix in Polyak.
kellertuer Aug 27, 2024
949457a
Update .github/workflows/ci.yml
kellertuer Aug 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@ All notable Changes to the Julia package `Manopt.jl` will be documented in this
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.4.64] unreleased

## [0.4.63] June 4, 2024

### Changed

* Fixed a bug that Lanczos produced NaNs when started exactly in a minimizer, since we divide by the gradient norm.
* Fix ambiguities that occurred due to point being nonmutating.

## [0.4.63] May 11, 2024

Expand Down
11 changes: 2 additions & 9 deletions src/plans/gradient_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,6 @@ function get_subgradient!(
return get_gradient!(M, X, agmo, p)
end

function _to_mutating_gradient(grad_f, evaluation::AllocatingEvaluation)
return grad_f_(M, p) = [grad_f(M, p[])]
end
function _to_mutating_gradient(grad_f, evaluation::InplaceEvaluation)
return grad_f_(M, X, p) = (X .= [grad_f(M, p[])])
end

@doc raw"""
get_gradient(agst::AbstractGradientSolverState)

Expand Down Expand Up @@ -458,7 +451,6 @@ mutable struct Nesterov{P,R<:Real} <: DirectionUpdateRule
shrinkage::Function
inverse_retraction_method::AbstractInverseRetractionMethod
end
Nesterov(M::AbstractManifold, p::Number; kwargs...) = Nesterov(M, [p]; kwargs...)
function Nesterov(
M::AbstractManifold,
p::P;
Expand All @@ -469,7 +461,8 @@ function Nesterov(
M, P
),
) where {P,T}
return Nesterov{P,T}(γ, μ, copy(M, p), shrinkage, inverse_retraction_method)
p_ = _ensure_mutating_variable(p)
return Nesterov{typeof(p_),T}(γ, μ, copy(M, p_), shrinkage, inverse_retraction_method)
end
function (n::Nesterov)(mp::AbstractManoptProblem, s::AbstractGradientSolverState, i)
M = get_manifold(mp)
Expand Down
48 changes: 48 additions & 0 deletions src/plans/objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,54 @@ function ReturnManifoldObjective(
return ReturnManifoldObjective{E,O2,O1}(o)
end

function _ensure_mutating_cost(cost, q::Number)
return isnothing(cost) ? cost : (M, p) -> cost(M, p[])
end
function _ensure_mutating_cost(cost, p)
return cost
end

function _ensure_mutating_gradient(grad_f, p, evaluation::AbstractEvaluationType)
return grad_f
end
function _ensure_mutating_gradient(grad_f, q::Number, evaluation::AllocatingEvaluation)
return isnothing(grad_f) ? grad_f : (M, p) -> [grad_f(M, p[])]
end
function _ensure_mutating_gradient(grad_f, q::Number, evaluation::InplaceEvaluation)
return isnothing(grad_f) ? grad_f : (M, X, p) -> (X .= [grad_f(M, p[])])
end

function _ensure_mutating_hessian(hess_f, p, evaluation::AbstractEvaluationType)
return hess_f
end
function _ensure_mutating_hessian(hess_f, q::Number, evaluation::AllocatingEvaluation)
return isnothing(hess_f) ? hess_f : (M, p, X) -> [hess_f(M, p[], X[])]
end
function _ensure_mutating_hessian(hess_f, q::Number, evaluation::InplaceEvaluation)
return isnothing(hess_f) ? hess_f : (M, Y, p, X) -> (Y .= [hess_f(M, p[], X[])])
end

function _ensure_mutating_prox(prox_f, p, evaluation::AbstractEvaluationType)
return prox_f
end
function _ensure_mutating_prox(prox_f, q::Number, evaluation::AllocatingEvaluation)
return isnothing(prox_f) ? prox_f : (M, λ, p) -> [prox_f(M, λ, p[])]
end
function _ensure_mutating_prox(prox_f, q::Number, evaluation::InplaceEvaluation)
return isnothing(prox_f) ? prox_f : (M, q, λ, p) -> (q .= [prox_f(M, λ, p[])])
end

_ensure_mutating_variable(p) = p
_ensure_mutating_variable(q::Number) = [q]
"""
_ensure_matching_output(p, q)
_ensure_matching_output(e, q, s)

If p is a number and q is a vector, return q[] (a number) otherwise q
"""
_ensure_matching_output(::T, q::Vector{T}) where {T} = length(q) == 1 ? q[] : q
_ensure_matching_output(p, q) = q

"""
dispatch_objective_decorator(o::AbstractManoptSolverState)

Expand Down
29 changes: 7 additions & 22 deletions src/solvers/DouglasRachford.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,30 +189,15 @@ function DouglasRachford(
parallel=0,
kwargs...,
) where {TF}
N, f_, (prox1, prox2), parallel_, p0 = parallel_to_alternating_DR(
M, f, proxes_f, p, parallel, evaluation
p_ = _ensure_mutating_variable(p)
f_ = _ensure_mutating_cost(f, p)
proxes_f_ = [_ensure_mutating_prox(prox_f, p, evaluation) for prox_f in proxes_f]
N, f__, (prox1, prox2), parallel_, q = parallel_to_alternating_DR(
M, f_, proxes_f_, p_, parallel, evaluation
)
mpo = ManifoldProximalMapObjective(f_, (prox1, prox2); evaluation=evaluation)
return DouglasRachford(N, mpo, p0; evaluation=evaluation, parallel=parallel_, kwargs...)
end
function DouglasRachford(
M::AbstractManifold,
f::TF,
proxes_f::Vector{<:Any},
p::Number;
evaluation::AbstractEvaluationType=AllocatingEvaluation(),
kwargs...,
) where {TF}
q = [p]
f_(M, p) = f(M, p[])
if evaluation isa AllocatingEvaluation
proxes_f_ = [(M, λ, p) -> [pf(M, λ, p[])] for pf in proxes_f]
else
proxes_f_ = [(M, q, λ, p) -> (q .= [pf(M, λ, p[])]) for pf in proxes_f]
end
rs = DouglasRachford(M, f_, proxes_f_, q; evaluation=evaluation, kwargs...)
#return just a number if the return type is the same as the type of q
return (typeof(q) == typeof(rs)) ? rs[] : rs
rs = DouglasRachford(N, mpo, q; evaluation=evaluation, parallel=parallel_, kwargs...)
return _ensure_matching_output(p, rs)
end
function DouglasRachford(
M::AbstractManifold, mpo::O, p; kwargs...
Expand Down
24 changes: 6 additions & 18 deletions src/solvers/FrankWolfe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,12 @@ function Frank_Wolfe_method(
evaluation::AbstractEvaluationType=AllocatingEvaluation(),
kwargs...,
)
mgo = ManifoldGradientObjective(f, grad_f; evaluation=evaluation)
return Frank_Wolfe_method(M, mgo, p; evaluation=evaluation, kwargs...)
end
function Frank_Wolfe_method(
M::AbstractManifold,
f,
grad_f,
p::Number;
evaluation::AbstractEvaluationType=AllocatingEvaluation(),
kwargs...,
)
# redefine initial point
q = [p]
f_(M, p) = f(M, p[])
grad_f_ = _to_mutating_gradient(grad_f, evaluation)
rs = Frank_Wolfe_method(M, f_, grad_f_, q; evaluation=evaluation, kwargs...)
#return just a number if the return type is the same as the type of q
return (typeof(q) == typeof(rs)) ? rs[] : rs
p_ = _ensure_mutating_variable(p)
f_ = _ensure_mutating_cost(f, p)
grad_f_ = _ensure_mutating_gradient(grad_f, p, evaluation)
mgo = ManifoldGradientObjective(f_, grad_f_; evaluation=evaluation)
rs = Frank_Wolfe_method(M, mgo, p_; evaluation=evaluation, kwargs...)
return _ensure_matching_output(p, rs)
end
function Frank_Wolfe_method(
M::AbstractManifold, mgo::O, p; kwargs...
Expand Down
8 changes: 3 additions & 5 deletions src/solvers/NelderMead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,19 @@ end
function NelderMeadSimplex(M::AbstractManifold)
return NelderMeadSimplex([rand(M) for i in 1:(manifold_dimension(M) + 1)])
end
function NelderMeadSimplex(M::AbstractManifold, p::Number, B::AbstractBasis; kwargs...)
return NelderMeadSimplex(M, [p], B; kwargs...)
end
function NelderMeadSimplex(
M::AbstractManifold,
p,
B::AbstractBasis=DefaultOrthonormalBasis();
a::Real=0.025,
retraction_method::AbstractRetractionMethod=default_retraction_method(M, typeof(p)),
)
p_ = _ensure_mutating_variable(p)
M_dim = manifold_dimension(M)
vecs = [
get_vector(M, p, [ifelse(i == j, a, zero(a)) for i in 1:M_dim], B) for j in 0:M_dim
get_vector(M, p_, [ifelse(i == j, a, zero(a)) for i in 1:M_dim], B) for j in 0:M_dim
]
pts = map(X -> retract(M, p, X, retraction_method), vecs)
pts = map(X -> retract(M, p_, X, retraction_method), vecs)
return NelderMeadSimplex(pts)
end

Expand Down
33 changes: 7 additions & 26 deletions src/solvers/adaptive_regularization_with_cubics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,32 +257,13 @@ function adaptive_regularization_with_cubics(
evaluation::AbstractEvaluationType=AllocatingEvaluation(),
kwargs...,
) where {TF,TDF,THF}
mho = ManifoldHessianObjective(f, grad_f, Hess_f; evaluation=evaluation)
return adaptive_regularization_with_cubics(M, mho, p; evaluation=evaluation, kwargs...)
end
function adaptive_regularization_with_cubics(
M::AbstractManifold,
f::TF,
grad_f::TDF,
Hess_f::THF,
p::Number;
evaluation::AbstractEvaluationType=AllocatingEvaluation(),
kwargs...,
) where {TF,TDF,THF}
q = [p]
f_(M, p) = f(M, p[])
Hess_f_ = Hess_f
if evaluation isa AllocatingEvaluation
grad_f_ = (M, p) -> [grad_f(M, p[])]
Hess_f_ = (M, p, X) -> [Hess_f(M, p[], X[])]
else
grad_f_ = (M, X, p) -> (X .= [grad_f(M, p[])])
Hess_f_ = (M, Y, p, X) -> (Y .= [Hess_f(M, p[], X[])])
end
rs = adaptive_regularization_with_cubics(
M, f_, grad_f_, Hess_f_, q; evaluation=evaluation, kwargs...
)
return (typeof(q) == typeof(rs)) ? rs[] : rs
p_ = _ensure_mutating_variable(p)
f_ = _ensure_mutating_cost(f, p)
grad_f_ = _ensure_mutating_gradient(grad_f, p, evaluation)
Hess_f_ = _ensure_mutating_hessian(Hess_f, p, evaluation)
mho = ManifoldHessianObjective(f_, grad_f_, Hess_f_; evaluation=evaluation)
rs = adaptive_regularization_with_cubics(M, mho, p_; evaluation=evaluation, kwargs...)
return _ensure_matching_output(p, rs)
end
function adaptive_regularization_with_cubics(M::AbstractManifold, f, grad_f; kwargs...)
return adaptive_regularization_with_cubics(M, f, grad_f, rand(M); kwargs...)
Expand Down
42 changes: 13 additions & 29 deletions src/solvers/augmented_Lagrangian_method.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,43 +267,27 @@ function augmented_Lagrangian_method(
grad_h=nothing,
kwargs...,
) where {TF,TGF}
q = copy(M, p)
p_ = _ensure_mutating_variable(p)
f_ = _ensure_mutating_cost(f, p)
grad_f_ = _ensure_mutating_gradient(grad_f, p, evaluation)
g_ = _ensure_mutating_cost(g, p)
grad_g_ = _ensure_mutating_gradient(grad_g, p, evaluation)
h_ = _ensure_mutating_cost(h, p)
grad_h_ = _ensure_mutating_gradient(grad_h, p, evaluation)

cmo = ConstrainedManifoldObjective(
f, grad_f, g, grad_g, h, grad_h; evaluation=evaluation
f_, grad_f_, g_, grad_g_, h_, grad_h_; evaluation=evaluation
)
return augmented_Lagrangian_method!(M, cmo, q; evaluation=evaluation, kwargs...)
rs = augmented_Lagrangian_method(M, cmo, p_; evaluation=evaluation, kwargs...)
return _ensure_matching_output(p, rs)
end
function augmented_Lagrangian_method(
M::AbstractManifold, cmo::O, p=rand(M); kwargs...
M::AbstractManifold, cmo::O, p=rand(M);
kwargs...
kellertuer marked this conversation as resolved.
Show resolved Hide resolved
) where {O<:Union{ConstrainedManifoldObjective,AbstractDecoratedManifoldObjective}}
q = copy(M, p)
return augmented_Lagrangian_method!(M, cmo, q; kwargs...)
end
function augmented_Lagrangian_method(
M::AbstractManifold,
f::TF,
grad_f::TGF,
p::Number;
evaluation::AbstractEvaluationType=AllocatingEvaluation(),
g=nothing,
grad_g=nothing,
grad_h=nothing,
h=nothing,
kwargs...,
) where {TF,TGF}
q = [p]
f_(M, p) = f(M, p[])
grad_f_ = _to_mutating_gradient(grad_f, evaluation)
g_ = isnothing(g) ? nothing : (M, p) -> g(M, p[])
grad_g_ = isnothing(grad_g) ? nothing : _to_mutating_gradient(grad_g, evaluation)
h_ = isnothing(h) ? nothing : (M, p) -> h(M, p[])
grad_h_ = isnothing(grad_h) ? nothing : _to_mutating_gradient(grad_h, evaluation)
cmo = ConstrainedManifoldObjective(
f_, grad_f_, g_, grad_g_, h_, grad_h_; evaluation=evaluation
)
rs = augmented_Lagrangian_method(M, cmo, q; evaluation=evaluation, kwargs...)
return (typeof(q) == typeof(rs)) ? rs[] : rs
end

@doc raw"""
augmented_Lagrangian_method!(M, f, grad_f, p=rand(M); kwargs...)
Expand Down
24 changes: 6 additions & 18 deletions src/solvers/conjugate_gradient_descent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,24 +94,12 @@ end
function conjugate_gradient_descent(
M::AbstractManifold, f::TF, grad_f::TDF, p; evaluation=AllocatingEvaluation(), kwargs...
) where {TF,TDF}
mgo = ManifoldGradientObjective(f, grad_f; evaluation=evaluation)
return conjugate_gradient_descent(M, mgo, p; evaluation=evaluation, kwargs...)
end
function conjugate_gradient_descent(
M::AbstractManifold,
f::TF,
grad_f::TDF,
p::Number;
evaluation::AbstractEvaluationType=AllocatingEvaluation(),
kwargs...,
) where {TF,TDF}
# redefine initial point
q = [p]
f_(M, p) = f(M, p[])
grad_f_ = _to_mutating_gradient(grad_f, evaluation)
rs = conjugate_gradient_descent(M, f_, grad_f_, q; evaluation=evaluation, kwargs...)
#return just a number if the return type is the same as the type of q
return (typeof(q) == typeof(rs)) ? rs[] : rs
p_ = _ensure_mutating_variable(p)
f_ = _ensure_mutating_cost(f, p)
grad_f_ = _ensure_mutating_gradient(grad_f, p, evaluation)
mgo = ManifoldGradientObjective(f_, grad_f_; evaluation=evaluation)
rs = conjugate_gradient_descent(M, mgo, p_; evaluation=evaluation, kwargs...)
return _ensure_matching_output(p, rs)
end
function conjugate_gradient_descent(
M::AbstractManifold, mgo::O, p=rand(M); kwargs...
Expand Down
27 changes: 6 additions & 21 deletions src/solvers/cyclic_proximal_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,12 @@ function cyclic_proximal_point(
evaluation::AbstractEvaluationType=AllocatingEvaluation(),
kwargs...,
)
mpo = ManifoldProximalMapObjective(f, proxes_f; evaluation=evaluation)
return cyclic_proximal_point(M, mpo, p; evaluation=evaluation, kwargs...)
end
function cyclic_proximal_point(
M::AbstractManifold,
f,
proxes_f::Union{Tuple,AbstractVector},
p::Number;
evaluation::AbstractEvaluationType=AllocatingEvaluation(),
kwargs...,
)
q = [p]
f_(M, p) = f(M, p[])
if evaluation isa AllocatingEvaluation
proxes_f_ = [(M, λ, p) -> [pf(M, λ, p[])] for pf in proxes_f]
else
proxes_f_ = [(M, q, λ, p) -> (q .= [pf(M, λ, p[])]) for pf in proxes_f]
end
rs = cyclic_proximal_point(M, f_, proxes_f_, q; evaluation=evaluation, kwargs...)
#return just a number if the return type is the same as the type of q
return (typeof(q) == typeof(rs)) ? rs[] : rs
p_ = _ensure_mutating_variable(p)
f_ = _ensure_mutating_cost(f, p)
proxes_f_ = [_ensure_mutating_prox(prox_f, p, evaluation) for prox_f in proxes_f]
mpo = ManifoldProximalMapObjective(f_, proxes_f_; evaluation=evaluation)
rs = cyclic_proximal_point(M, mpo, p_; evaluation=evaluation, kwargs...)
return _ensure_matching_output(p, rs)
end
function cyclic_proximal_point(
M::AbstractManifold, mpo::O, p; kwargs...
Expand Down
Loading
Loading