From 22aaf3a1abdd06aff25ba44248218ef5bb6ff94f Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Tue, 10 Oct 2023 16:04:45 +0200 Subject: [PATCH] tests for Sasaki retraction --- src/point_vector_fallbacks.jl | 20 ++++++++++++++++++++ src/retractions.jl | 9 +++++++++ test/default_manifold.jl | 11 +++++++++++ 3 files changed, 40 insertions(+) diff --git a/src/point_vector_fallbacks.jl b/src/point_vector_fallbacks.jl index 5b36dab5..797e67cb 100644 --- a/src/point_vector_fallbacks.jl +++ b/src/point_vector_fallbacks.jl @@ -331,6 +331,26 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol) ManifoldsBase.retract_embedded!(M, q.$pfield, p.$pfield, X.$vfield, t, m) return q end + function ManifoldsBase.retract_sasaki( + M::$TM, + p::$TP, + X::$TV, + t::Number, + m::SasakiRetraction, + ) + return $TP(ManifoldsBase.retract_sasaki(M, p.$pfield, X.$vfield, t, m)) + end + function ManifoldsBase.retract_sasaki!( + M::$TM, + q::$TP, + p::$TP, + X::$TV, + t::Number, + m::SasakiRetraction, + ) + ManifoldsBase.retract_sasaki!(M, q.$pfield, p.$pfield, X.$vfield, t, m) + return q + end end, ) for f_postfix in [:polar, :project, :qr, :softmax] diff --git a/src/retractions.jl b/src/retractions.jl index ea74671e..62071e06 100644 --- a/src/retractions.jl +++ b/src/retractions.jl @@ -1074,6 +1074,15 @@ retract_softmax!(M::AbstractManifold, q, p, X, t::Number) function retract_softmax! end +""" + retract_sasaki!(M::AbstractManifold, q, p, X, t::Number, m::SasakiRetraction) + +Compute the in-place variant of the [`SasakiRetraction`](@ref) `m`. +""" +retract_pade!(M::AbstractManifold, q, p, X, t::Number, m::SasakiRetraction) + +function retract_sasaki! end + @doc raw""" retract(M::AbstractManifold, p, X, method::AbstractRetractionMethod=default_retraction_method(M, typeof(p))) retract(M::AbstractManifold, p, X, t::Number=1, method::AbstractRetractionMethod=default_retraction_method(M, typeof(p))) diff --git a/test/default_manifold.jl b/test/default_manifold.jl index 73e34b3f..4c8c1e22 100644 --- a/test/default_manifold.jl +++ b/test/default_manifold.jl @@ -181,6 +181,16 @@ function ManifoldsBase.retract_exp_ode!( return (q .= p .+ t .* X) end ManifoldsBase.retract_pade!(::DefaultManifold, q, p, X, t::Number, i) = (q .= p .+ t .* X) +function ManifoldsBase.retract_sasaki!( + ::DefaultManifold, + q, + p, + X, + t::Number, + ::SasakiRetraction, +) + return (q .= p .+ t .* X) +end ManifoldsBase.retract_softmax!(::DefaultManifold, q, p, X, t::Number) = (q .= p .+ t .* X) ManifoldsBase.get_embedding(M::DefaultManifold) = M # dummy embedding ManifoldsBase.inverse_retract_polar!(::DefaultManifold, Y, p, q) = (Y .= q .- p) @@ -723,6 +733,7 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),) ODEExponentialRetraction(PolarRetraction(), DefaultBasis()), PadeRetraction(2), EmbeddedRetraction(ExponentialRetraction()), + SasakiRetraction(5), ] @test retract(M, q, Y, retr) == DefaultPoint(q.value + Y.value) @test retract(M, q, Y, 0.5, retr) == DefaultPoint(q.value + 0.5 * Y.value)