Skip to content

Commit

Permalink
starts product manifolds with an RecursiveArrayToolsExtension
Browse files Browse the repository at this point in the history
  • Loading branch information
kellertuer committed Nov 25, 2024
1 parent 1b8e5ee commit 0680095
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 8 deletions.
7 changes: 7 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,18 @@ Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"

[extensions]
LieGroupsRecursiveArrayToolsExt = "RecursiveArrayTools"

[compat]
Aqua = "0.8"
LinearAlgebra = "1.6"
Manifolds = "0.10.5"
ManifoldsBase = "0.15.20"
RecursiveArrayTools = "2, 3"
Random = "1.6"
Test = "1.6"
julia = "1.6"
Expand Down
15 changes: 15 additions & 0 deletions ext/LieGroupsRecursiveArrayToolsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module LieGroupsRecursiveArrayToolsExt

using LieGroups, ManifoldsBase, RecursiveArrayTools

using LieGroups: identity_element, LieGroup

function ManifoldsBase.allocate_result(
G::LieGroup{𝔽,Op,M}, ::typeof(identity_element)
) where {𝔽,Op,<:ManifoldsBase.ProductManifold}
M = base_manifold(G)
Ls = LieGroup.(M.manifolds, G.op.operations)
ps = ManifoldsBase.allocate_result.(Ls, Ref(identity_element))
return ArrayPartition(ps...)
end
end
54 changes: 54 additions & 0 deletions src/groups/product_group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,48 @@ function ProductLieGroup(G::LieGroup, H::LieGroup)
return LieGroup(G.manifold × H.manifold, G.op × H.op)
end

function _compose!(
PrG::LieGroup{𝔽,Op,M}, k, g, h
) where {𝔽,Op<:ProductGroupOperation,M<:ManifoldsBase.ProductManifold}
map(
compose!,
LieGroup.(PrG.manifold.manifolds, PrG.op.operations),
submanifold_components(PrG.manifold, k),
submanifold_components(PrG.manifold, g),
submanifold_components(PrG.manifold, h),
)
return k
end

function ManifoldsBase.check_size(
PrG::LieGroup{𝔽,Op,M}, g
) where {𝔽,Op<:ProductGroupOperation,M<:ManifoldsBase.ProductManifold}
return ManifoldsBase.check_size(PrG.manifold, g)
end
function ManifoldsBase.check_size(
::LieGroup{𝔽,Op,M}, ::Identity
) where {𝔽,Op<:ProductGroupOperation,M<:ManifoldsBase.ProductManifold}
return nothing
end
function ManifoldsBase.check_size(
PrG::LieGroup{𝔽,Op,M}, g, X
) where {𝔽,Op<:ProductGroupOperation,M<:ManifoldsBase.ProductManifold}
return ManifoldsBase.check_size(PrG.manifold, g, X)
end

function conjugate!(
PrG::LieGroup{𝔽,Op,M}, k, g, h
) where {𝔽,Op<:ProductGroupOperation,M<:ManifoldsBase.ProductManifold}
map(
conjugate,
LieGroup.(PrG.manifold.manifolds, PrG.op.operations),
submanifold_components(PrG.manifold, k),
submanifold_components(PrG.manifold, g),
submanifold_components(PrG.manifold, h),
)
return k
end

@doc raw"""
cross(G, H)
G × H
Expand All @@ -75,6 +117,18 @@ function LinearAlgebra.cross(G::LieGroup, H::LieGroup)
return ProductLieGroup(G, H)
end

function inv!(
PrG::LieGroup{𝔽,Op,M}, h, g
) where {𝔽,Op<:ProductGroupOperation,M<:ManifoldsBase.ProductManifold}
map(
inv!,
LieGroup.(PrG.manifold.manifolds, PrG.op.operations),
submanifold_components(M, h),
submanifold_components(M, g),
)
return h
end

function Base.show(
io::IO, G::LieGroup{𝔽,<:ProductGroupOperation,<:ManifoldsBase.ProductManifold}
) where {𝔽}
Expand Down
18 changes: 10 additions & 8 deletions test/groups/test_product_group.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
using LieGroups, Test, ManifoldsBase
using LieGroups, Test, ManifoldsBase, RecursiveArrayTools

s = joinpath(@__DIR__, "..", "LieGroupsTestSuite.jl")
!(s in LOAD_PATH) && (push!(LOAD_PATH, s))
using LieGroupsTestSuite

@testset "Generic product Lie group" begin
M = LieGroupsTestSuite.DummyManifold()
op = LieGroupsTestSuite.DummyOperation()
G = LieGroup(M, op)
G2 = G × G
G = TranslationGroup(2) × TranslationGroup(2)
g, h = ArrayPartition([1.0, 0.0], [0.0, 3.0]), ArrayPartition([0.0, 1.0], [2.0, 0.0])
X, Y = ArrayPartition([0.0, 0.1], [0.2, 0.0]), ArrayPartition([0.1, 0.2], [0.0, 0.3])

properties = Dict(
:Name => "The Product Manifold",
# :Rng => Random.MersenneTwister(),
:Points => [g, h],
:Vectors => [X, Y],
:Functions => [
# compose,
compose,
# conjugate,
# diff_conjugate,
# diff_inv,
Expand All @@ -34,11 +35,12 @@ using LieGroupsTestSuite
],
)
expectations = Dict(
:repr => "ProductLieGroup(LieGroupsTestSuite.DummyManifold() × LieGroupsTestSuite.DummyManifold(), LieGroupsTestSuite.DummyOperation() × LieGroupsTestSuite.DummyOperation())",
:repr => "ProductLieGroup(Euclidean(2; field=ℝ) × Euclidean(2; field=ℝ), AdditionGroupOperation() × AdditionGroupOperation())",
)
test_lie_group(G2, properties, expectations)
test_lie_group(G, properties, expectations)

@testset "Product Operation generators" begin
op = LieGroupsTestSuite.DummyOperation()
op2 = LieGroupsTestSuite.DummySecondOperation()
O1 = op × op2
O2 = op2 × op
Expand Down

0 comments on commit 0680095

Please sign in to comment.