Skip to content

Commit

Permalink
Fix fallbacks, add code cov.
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Nov 2, 2024
1 parent 187c78c commit 95dfccd
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 24 deletions.
14 changes: 6 additions & 8 deletions src/Lie_algebra/Lie_algebra_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,22 @@ function LieAlgebra(G::LieGroup{𝔽,O}) where {𝔽,O<:AbstractGroupOperation}
end

function ManifoldsBase.get_coordinates(𝔤::LieAlgebra, X, B::ManifoldsBase.AbstractBasis)
c = ManifoldsBase.allocate_result(B, get_coordinates, X, B)
get_coordinates!(𝔤, c, X, B)
return X
G = 𝔤.manifold
return get_coordinates(base_manifold(G), identity_element(G), X, B)
end
function ManifoldsBase.get_coordinates!(𝔤::LieAlgebra, c, X, B::ManifoldsBase.AbstractBasis)
G = 𝔤.manifold
get_coordinates!(base_manifold(𝔤), c, identity_element(G), X, B)
get_coordinates!(base_manifold(G), c, identity_element(G), X, B)
return c
end

function ManifoldsBase.get_vector(𝔤::LieAlgebra, c, B::ManifoldsBase.AbstractBasis)
X = zero_vector(𝔤)
get_vector!(𝔤, X, c, B)
return X
G = 𝔤.manifold
return get_vector(base_manifold(G), identity_element(G), c, B)
end
function ManifoldsBase.get_vector!(𝔤::LieAlgebra, X, c, B::ManifoldsBase.AbstractBasis)
G = 𝔤.manifold
get_vector!(base_manifold(𝔤), X, identity_element(G), c, B)
get_vector!(base_manifold(G), X, identity_element(G), c, B)
return X
end

Expand Down
32 changes: 16 additions & 16 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -447,11 +447,13 @@ ManifoldsBase.get_coordinates(G::LieGroup, g, X, B::ManifoldsBase.AbstractBasis)
ManifoldsBase.get_coordinates!(G::LieGroup, c, g, X, B::ManifoldsBase.AbstractBasis)

function get_coordinates_lie(G::LieGroup, g, X, N)
return get_coordinates(G, identity_element(G), X, DefaultOrthogonalBasis(N))
return get_coordinates(
base_manifold(G), identity_element(G), X, ManifoldsBase.DefaultOrthogonalBasis(N)
)
end
function get_coordinates_lie!(G::LieGroup, Y, g, X, N)
return get_coordinates!(
base_manifold(G), Y, identity_element(G), X, DefaultOrthogonalBasis(N)
base_manifold(G), Y, identity_element(G), X, ManifoldsBase.DefaultOrthogonalBasis(N)
)
end

Expand Down Expand Up @@ -483,11 +485,13 @@ ManifoldsBase.get_vector(G::LieGroup, g, c, B::ManifoldsBase.AbstractBasis)
ManifoldsBase.get_vector!(G::LieGroup, X, g, c, B::ManifoldsBase.AbstractBasis)

@inline function get_vector_lie(G::LieGroup, g, c, N)
return get_vector(base_manifold(G), identity_element(G), c, DefaultOrthogonalBasis(N))
return get_vector(
base_manifold(G), identity_element(G), c, ManifoldsBase.DefaultOrthogonalBasis(N)
)
end
@inline function get_vector_lie!(G::LieGroup, Y, g, c, N)
return get_vector!(
base_manifold(G), Y, identity_element(G), c, DefaultOrthogonalBasis(N)
base_manifold(G), Y, identity_element(G), c, ManifoldsBase.DefaultOrthogonalBasis(N)
)
end

Expand All @@ -514,15 +518,14 @@ Technically `hat` is a specific case of [`get_vector`](@ref) and is implemented

# function hat end
@doc "$(_doc_hat)"
function hat(G::LieGroup, c)
X = zero_vector(LieAlgebra(G))
return hat!(G, X, c)
function hat(G::LieGroup{𝔽}, c) where {𝔽}
return get_vector_lie(G, Identity(G), c, 𝔽)
end

# function hat! end
@doc "$(_doc_hat)"
function hat!(G::LieGroup, X, c)
get_vector!(G, X, identity_element(G), c, LieAlgebraOrthogonalBasis())
function hat!(G::LieGroup{𝔽}, X, c) where {𝔽}
get_vector_lie!(G, X, Identity(G), c, 𝔽)
return X
end

Expand Down Expand Up @@ -845,17 +848,14 @@ Technically `hat` is a specific case of [`get_coordinates`](@ref) and is impleme

# function vee end
@doc "$(_doc_vee)"
function vee(G::LieGroup, X)
c = ManifoldsBase.allocate_result(G, vee, X)
return vee!(G, c, X)
function vee(G::LieGroup{𝔽}, X) where {𝔽}
return get_coordinates_lie(G, Identity(G), X, 𝔽)
end

# function vee! end
@doc "$(_doc_vee)"
function vee!(G::LieGroup, c, X)
get_coordinates!(
G, c, identity_element(G), X, LieAlgebraOrthogonalBasis(ManifoldsBase.ℝ)
)
function vee!(G::LieGroup{𝔽}, c, X) where {𝔽}
get_coordinates_lie!(G, c, Identity(G), X, 𝔽)
return c
end

Expand Down
22 changes: 22 additions & 0 deletions test/test_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,25 @@ using LieGroupsTestSuite
@test_throws MethodError log!(G, X, e, g)
end
end
@testset "Generic Lie Algebra Interface functions" begin
M = ManifoldsBase.DefaultManifold(2)
op = AdditionGroupOperation()
G = LieGroup(M, op)
𝔤 = LieAlgebra(G)
@testset "Generic get_coordinates/get_vector passthrough on 𝔤" begin
B = DefaultOrthonormalBasis()
p = [1.0, 2.0]
q = [0.0, 0.0]
# coordinates and vector on 𝔤 are here the same as the ones on M at 0
X = [1.0, 0.0]
@test get_coordinates(𝔤, X, B) == get_coordinates(M, q, X, B)
Y = copy(X)
@test get_coordinates!(𝔤, Y, X, B) == get_coordinates!(M, Y, q, X, B)
@test X == Y
c = [0.0, 1.0]
@test get_vector(𝔤, c, B) == get_vector(M, q, c, B)
d = copy(c)
@test get_vector!(𝔤, d, c, B) == get_vector!(M, d, q, c, B)
@test c == d
end
end

0 comments on commit 95dfccd

Please sign in to comment.