From 79382e7bcd2b26f4692ff00d93003a7baf5e6a5b Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Fri, 23 Aug 2024 18:17:24 +0200 Subject: [PATCH] Add test coverage. --- src/plans/conjugate_gradient_plan.jl | 12 +++++------- src/plans/manifold_default_factory.jl | 1 + test/plans/test_conjugate_gradient_plan.jl | 2 ++ test/plans/test_defaults_factory.jl | 15 +++++++++++++++ test/runtests.jl | 1 + 5 files changed, 24 insertions(+), 7 deletions(-) create mode 100644 test/plans/test_defaults_factory.jl diff --git a/src/plans/conjugate_gradient_plan.jl b/src/plans/conjugate_gradient_plan.jl index c14e7d7193..156992eff8 100644 --- a/src/plans/conjugate_gradient_plan.jl +++ b/src/plans/conjugate_gradient_plan.jl @@ -233,7 +233,7 @@ struct DaiYuanCoefficientRule{VTM<:AbstractVectorTransportMethod} <: DirectionUp vector_transport_method::VTM end function DaiYuanCoefficientRule( - M::AbstractManifold=ManifoldsBase.DefaultManifold(); + M::AbstractManifold=DefaultManifold(); vector_transport_method::VTM=default_vector_transport_method(M), ) where {VTM<:AbstractVectorTransportMethod} return DaiYuanCoefficientRule{VTM}(vector_transport_method) @@ -394,7 +394,7 @@ mutable struct HagerZhangCoefficientRule{VTM<:AbstractVectorTransportMethod} <: vector_transport_method::VTM end function HagerZhangCoefficientRule( - M::AbstractManifold=ManifoldsBase.DefaultManifold(); + M::AbstractManifold=DefaultManifold(); vector_transport_method::VTM=default_vector_transport_method(M), ) where {VTM<:AbstractVectorTransportMethod} return HagerZhangCoefficientRule{VTM}(vector_transport_method) @@ -503,7 +503,7 @@ struct HestenesStiefelCoefficientRule{VTM<:AbstractVectorTransportMethod} <: vector_transport_method::VTM end function HestenesStiefelCoefficientRule( - M::AbstractManifold=ManifoldsBase.DefaultManifold(); + M::AbstractManifold=DefaultManifold(); vector_transport_method::VTM=default_vector_transport_method(M), ) where {VTM<:AbstractVectorTransportMethod} return HestenesStiefelCoefficientRule{VTM}(vector_transport_method) @@ -690,7 +690,7 @@ struct PolakRibiereCoefficientRule{VTM<:AbstractVectorTransportMethod} <: vector_transport_method::VTM end function PolakRibiereCoefficientRule( - M::AbstractManifold=ManifoldsBase.DefaultManifold(); + M::AbstractManifold=DefaultManifold(); vector_transport_method::VTM=default_vector_transport_method(M), ) where {VTM<:AbstractVectorTransportMethod} return PolakRibiereCoefficientRule{VTM}(vector_transport_method) @@ -859,9 +859,7 @@ end function ConjugateGradientBealeRestartRule( direction_update::Union{DirectionUpdateRule,ManifoldDefaultsFactory}; kwargs... ) - return ConjugateGradientBealeRestartRule( - ManifoldsBase.DefaultManifold(), direction_update; kwargs... - ) + return ConjugateGradientBealeRestartRule(DefaultManifold(), direction_update; kwargs...) end @inline function update_rule_storage_points(dur::ConjugateGradientBealeRestartRule) diff --git a/src/plans/manifold_default_factory.jl b/src/plans/manifold_default_factory.jl index 30f737525f..2aa389c760 100644 --- a/src/plans/manifold_default_factory.jl +++ b/src/plans/manifold_default_factory.jl @@ -80,6 +80,7 @@ function (mdf::ManifoldDefaultsFactory{T,<:AbstractManifold})() where {T} if mdf.constructor_requires_manifold return T(mdf.M, mdf.args...; mdf.kwargs...) else + println("frmops $(mdf.args...)") return T(mdf.args...; mdf.kwargs...) end end diff --git a/test/plans/test_conjugate_gradient_plan.jl b/test/plans/test_conjugate_gradient_plan.jl index b7fee4bee9..18b14258fd 100644 --- a/test/plans/test_conjugate_gradient_plan.jl +++ b/test/plans/test_conjugate_gradient_plan.jl @@ -59,6 +59,8 @@ Manopt.update_rule_storage_vectors(::DummyCGCoeff) = Tuple{} ) s1 = "Manopt.ConjugateGradientBealeRestartRule(Manopt.ConjugateDescentCoefficientRule(); threshold=0.2, vector_transport_method=$(pt))" @test repr(cgbr) == s1 + cgbr2 = Manopt.ConjugateGradientBealeRestartRule(ConjugateDescentCoefficient()) + @test cgbr.threshold == cgbr.threshold @test repr(LiuStoreyCoefficient(M)()) == "Manopt.LiuStoreyCoefficientRule(; vector_transport_method=$pt)" end diff --git a/test/plans/test_defaults_factory.jl b/test/plans/test_defaults_factory.jl new file mode 100644 index 0000000000..409ab128a7 --- /dev/null +++ b/test/plans/test_defaults_factory.jl @@ -0,0 +1,15 @@ +using Manopt, Manifolds, Test + +# A rule taht does not need a manifold but has defaults +struct FactoryDummyRule{R<:Real} + t::R + FactoryDummyRule(; t::R=1.0) where {R<:Real} = new{R}(t) +end + +@testset "ManifoldsDefaultFactory" begin + fdr = Manopt.ManifoldDefaultsFactory( + FactoryDummyRule, Sphere(2); requires_manifold=false, t=2.0 + ) + @test fdr().t == 2.0 + @test fdr(Euclidean(2)).t == 2.0 +end diff --git a/test/runtests.jl b/test/runtests.jl index da48b4ce85..e254f85b13 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,6 +23,7 @@ include("utils/example_tasks.jl") include("plans/test_primal_dual_plan.jl") include("plans/test_proximal_plan.jl") include("plans/test_higher_order_primal_dual_plan.jl") + include("plans/test_defaults_factory.jl") include("plans/test_record.jl") include("plans/test_stepsize.jl") include("plans/test_stochastic_gradient_plan.jl")