diff --git a/src/PowerManifold.jl b/src/PowerManifold.jl index c188c713..95ca9e01 100644 --- a/src/PowerManifold.jl +++ b/src/PowerManifold.jl @@ -172,6 +172,52 @@ function Base.:^( return PowerManifold(M, size...) end +function allocate_as( + M::PowerManifold{ + 𝔽, + TM, + TSize, + <:Union{NestedPowerRepresentation,NestedReplacingPowerRepresentation}, + }, +) where {𝔽,TM<:AbstractManifold{𝔽},TSize} + return [allocate_as(M.manifold) for _ in get_iterator(M)] +end +function allocate_as( + M::PowerManifold{ + 𝔽, + TM, + TSize, + <:Union{NestedPowerRepresentation,NestedReplacingPowerRepresentation}, + }, + ::Type{<:Array{U}}, +) where {𝔽,TM<:AbstractManifold{𝔽},TSize,U} + return [allocate_as(M.manifold, U) for _ in get_iterator(M)] +end + +function allocate_as( + M::PowerManifold{ + 𝔽, + TM, + TSize, + <:Union{NestedPowerRepresentation,NestedReplacingPowerRepresentation}, + }, + ft::TangentSpaceType, +) where {𝔽,TM<:AbstractManifold{𝔽},TSize} + return [allocate_as(M.manifold, ft) for _ in get_iterator(M)] +end +function allocate_as( + M::PowerManifold{ + 𝔽, + TM, + TSize, + <:Union{NestedPowerRepresentation,NestedReplacingPowerRepresentation}, + }, + ft::TangentSpaceType, + ::Type{<:Array{U}}, +) where {𝔽,TM<:AbstractManifold{𝔽},TSize,U} + return [allocate_as(M.manifold, ft, U) for _ in get_iterator(M)] +end + """ _allocate_access_nested(M::PowerManifoldNested, y, i) @@ -194,7 +240,7 @@ function allocate_result(M::PowerManifoldNested, f, x...) ] end end -# avoid ambituities - though usually not used +# avoid ambiguities - though usually not used function allocate_result( M::PowerManifoldNested, f::typeof(get_coordinates), diff --git a/test/power.jl b/test/power.jl index 390efa72..9d8359db 100644 --- a/test/power.jl +++ b/test/power.jl @@ -90,6 +90,26 @@ end allocate(M, p) isa Vector{SMatrix{2,2,Float64,4}} end + @testset "allocate_as" begin + M = ManifoldsBase.DefaultManifold(2, 2) + N = PowerManifold(M, NestedReplacingPowerRepresentation(), 2) + p = allocate_as(N) + @test p isa Vector{Matrix{Float64}} + @test size(p) == (2,) + + p = allocate_as(N, Vector{Matrix{Float32}}) + @test p isa Vector{Matrix{Float32}} + @test size(p) == (2,) + + X = allocate_as(N, TangentSpaceType()) + @test X isa Vector{Matrix{Float64}} + @test size(X) == (2,) + + X = allocate_as(N, TangentSpaceType(), Vector{Matrix{Float32}}) + @test X isa Vector{Matrix{Float32}} + @test size(X) == (2,) + end + for PowerRepr in [NestedPowerRepresentation, NestedReplacingPowerRepresentation] @testset "PowerManifold with $(PowerRepr)" begin M = ManifoldsBase.DefaultManifold(3) diff --git a/test/product_manifold.jl b/test/product_manifold.jl index 8030d7b4..d1dbe785 100644 --- a/test/product_manifold.jl +++ b/test/product_manifold.jl @@ -1,4 +1,3 @@ -using Revise using Test using ManifoldsBase using ManifoldsBase: DefaultManifold, submanifold_component, submanifold_components