Skip to content

Commit

Permalink
Add test coverage.
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Aug 23, 2024
1 parent 7c8a382 commit 79382e7
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/plans/conjugate_gradient_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ struct DaiYuanCoefficientRule{VTM<:AbstractVectorTransportMethod} <: DirectionUp
vector_transport_method::VTM
end
function DaiYuanCoefficientRule(
M::AbstractManifold=ManifoldsBase.DefaultManifold();
M::AbstractManifold=DefaultManifold();
vector_transport_method::VTM=default_vector_transport_method(M),
) where {VTM<:AbstractVectorTransportMethod}
return DaiYuanCoefficientRule{VTM}(vector_transport_method)
Expand Down Expand Up @@ -394,7 +394,7 @@ mutable struct HagerZhangCoefficientRule{VTM<:AbstractVectorTransportMethod} <:
vector_transport_method::VTM
end
function HagerZhangCoefficientRule(
M::AbstractManifold=ManifoldsBase.DefaultManifold();
M::AbstractManifold=DefaultManifold();
vector_transport_method::VTM=default_vector_transport_method(M),
) where {VTM<:AbstractVectorTransportMethod}
return HagerZhangCoefficientRule{VTM}(vector_transport_method)
Expand Down Expand Up @@ -503,7 +503,7 @@ struct HestenesStiefelCoefficientRule{VTM<:AbstractVectorTransportMethod} <:
vector_transport_method::VTM
end
function HestenesStiefelCoefficientRule(
M::AbstractManifold=ManifoldsBase.DefaultManifold();
M::AbstractManifold=DefaultManifold();
vector_transport_method::VTM=default_vector_transport_method(M),
) where {VTM<:AbstractVectorTransportMethod}
return HestenesStiefelCoefficientRule{VTM}(vector_transport_method)
Expand Down Expand Up @@ -690,7 +690,7 @@ struct PolakRibiereCoefficientRule{VTM<:AbstractVectorTransportMethod} <:
vector_transport_method::VTM
end
function PolakRibiereCoefficientRule(
M::AbstractManifold=ManifoldsBase.DefaultManifold();
M::AbstractManifold=DefaultManifold();
vector_transport_method::VTM=default_vector_transport_method(M),
) where {VTM<:AbstractVectorTransportMethod}
return PolakRibiereCoefficientRule{VTM}(vector_transport_method)
Expand Down Expand Up @@ -859,9 +859,7 @@ end
function ConjugateGradientBealeRestartRule(
direction_update::Union{DirectionUpdateRule,ManifoldDefaultsFactory}; kwargs...
)
return ConjugateGradientBealeRestartRule(
ManifoldsBase.DefaultManifold(), direction_update; kwargs...
)
return ConjugateGradientBealeRestartRule(DefaultManifold(), direction_update; kwargs...)
end

@inline function update_rule_storage_points(dur::ConjugateGradientBealeRestartRule)
Expand Down
1 change: 1 addition & 0 deletions src/plans/manifold_default_factory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ function (mdf::ManifoldDefaultsFactory{T,<:AbstractManifold})() where {T}
if mdf.constructor_requires_manifold
return T(mdf.M, mdf.args...; mdf.kwargs...)
else
println("frmops $(mdf.args...)")
return T(mdf.args...; mdf.kwargs...)
end
end
Expand Down
2 changes: 2 additions & 0 deletions test/plans/test_conjugate_gradient_plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ Manopt.update_rule_storage_vectors(::DummyCGCoeff) = Tuple{}
)
s1 = "Manopt.ConjugateGradientBealeRestartRule(Manopt.ConjugateDescentCoefficientRule(); threshold=0.2, vector_transport_method=$(pt))"
@test repr(cgbr) == s1
cgbr2 = Manopt.ConjugateGradientBealeRestartRule(ConjugateDescentCoefficient())
@test cgbr.threshold == cgbr.threshold
@test repr(LiuStoreyCoefficient(M)()) ==
"Manopt.LiuStoreyCoefficientRule(; vector_transport_method=$pt)"
end
Expand Down
15 changes: 15 additions & 0 deletions test/plans/test_defaults_factory.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
using Manopt, Manifolds, Test

# A rule taht does not need a manifold but has defaults
struct FactoryDummyRule{R<:Real}
t::R
FactoryDummyRule(; t::R=1.0) where {R<:Real} = new{R}(t)
end

@testset "ManifoldsDefaultFactory" begin
fdr = Manopt.ManifoldDefaultsFactory(
FactoryDummyRule, Sphere(2); requires_manifold=false, t=2.0
)
@test fdr().t == 2.0
@test fdr(Euclidean(2)).t == 2.0
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ include("utils/example_tasks.jl")
include("plans/test_primal_dual_plan.jl")
include("plans/test_proximal_plan.jl")
include("plans/test_higher_order_primal_dual_plan.jl")
include("plans/test_defaults_factory.jl")
include("plans/test_record.jl")
include("plans/test_stepsize.jl")
include("plans/test_stochastic_gradient_plan.jl")
Expand Down

0 comments on commit 79382e7

Please sign in to comment.