diff --git a/src/math_util.jl b/src/math_util.jl index 69f84db..c718b33 100644 --- a/src/math_util.jl +++ b/src/math_util.jl @@ -212,7 +212,9 @@ julia> partial_trace(ρ1⊗ρ2, [2, 2], [1]) 0.2 0.6 ``` """ -function partial_trace(ρ::Matrix, sys_dim::AbstractVector{Int}, dim_2_keep::AbstractVector{Int}) +partial_trace(ρ::AbstractMatrix, sys_dim::AbstractVector{<:Integer}, dim_2_keep::Integer) = partial_trace(ρ,sys_dim,[dim_2_keep]) + +function partial_trace(ρ::AbstractMatrix, sys_dim::AbstractVector{<:Integer}, dim_2_keep::AbstractVector{<:Integer}) size(ρ, 1) != size(ρ, 2) && throw(ArgumentError("ρ is not a square matrix.")) prod(sys_dim) != size(ρ, 1) && throw(ArgumentError("System dimensions do not multiply to density matrix dimension.")) @@ -226,7 +228,7 @@ function partial_trace(ρ::Matrix, sys_dim::AbstractVector{Int}, dim_2_keep::Abs re_dim[dim_2_trace] .= 1 tr_dim = copy(sys_dim) tr_dim[dim_2_keep] .= 1 - res = zeros(ComplexF64, re_dim..., re_dim...) + res = zeros(eltype(ρ), re_dim..., re_dim...) for I in CartesianIndices(size(res)) for k in CartesianIndices((tr_dim...,)) delta = CartesianIndex(k, k) @@ -358,4 +360,4 @@ function haar_unitary(dim) L = diag(r) L=L./abs.(L) q*diagm(0=>L) -end \ No newline at end of file +end