Skip to content

Commit

Permalink
address review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Oct 4, 2024
1 parent cf59553 commit f636904
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 8 deletions.
5 changes: 4 additions & 1 deletion src/DefaultManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ for fname in [:get_basis_orthonormal, :get_basis_orthogonal, :get_basis_default]
return CachedBasis($BT(N), [_euclidean_basis_vector(p, i) for i in eachindex(p)])
end
@eval function $fname(::DefaultManifold{ℂ}, p, N::ComplexNumbers)
return CachedBasis($BT(N), [_euclidean_basis_vector(p, i) for i in eachindex(p)])
return CachedBasis(
$BT(N),
[_euclidean_basis_vector(p, i, real) for i in eachindex(p)],
)
end
end
function get_basis_diagonalizing(M::DefaultManifold, p, B::DiagonalizingOrthonormalBasis)
Expand Down
12 changes: 7 additions & 5 deletions src/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,15 +406,17 @@ function _dual_basis(
return DefaultOrthonormalBasis{𝔽}(TangentSpaceType())
end

function _euclidean_basis_vector(p::StridedArray, i)
X = zero(p)
# if `p` has complex eltype but you'd like to have real basis vectors,
# you can pass `real` as a third argument to get that
function _euclidean_basis_vector(p::StridedArray, i, eltype_transform = identity)
X = zeros(eltype_transform(eltype(p)), size(p)...)
X[i] = 1
return X
end
function _euclidean_basis_vector(p, i)
function _euclidean_basis_vector(p, i, eltype_transform = identity)
# when p is for example a SArray
X = similar(p)
copyto!(X, zero(p))
X = similar(p, eltype_transform(eltype(p)))
fill!(X, zero(eltype(X)))
X[i] = 1
return X
end
Expand Down
4 changes: 2 additions & 2 deletions test/default_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ using ManifoldsBaseTestUtils
p = [1.0im, 2.0im, -1.0im]
CB = get_basis(MC, p, DefaultOrthonormalBasis(ManifoldsBase.ℂ))
@test CB.data == [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
@test CB.data isa Vector{Vector{ComplexF64}}
@test CB.data isa Vector{Vector{Float64}}
@test ManifoldsBase.coordinate_eltype(MC, p, ManifoldsBase.ℂ) === ComplexF64
@test ManifoldsBase.coordinate_eltype(MC, p, ManifoldsBase.ℝ) === Float64
CBR = get_basis(MC, p, DefaultOrthonormalBasis())
Expand Down Expand Up @@ -679,7 +679,7 @@ using ManifoldsBaseTestUtils
end

@testset "scalars" begin
M = DefaultManifold()
M = ManifoldsBase.DefaultManifold()
p = 1.0
X = 2.0
@test copy(M, p) === p
Expand Down

0 comments on commit f636904

Please sign in to comment.