Skip to content

Commit

Permalink
flexibilize types for partial_trace (#109)
Browse files Browse the repository at this point in the history
* flexibilize types for partial_trace
  • Loading branch information
araujoms authored Dec 19, 2023
1 parent fc2ec75 commit 50e551e
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/math_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ julia> partial_trace(ρ1⊗ρ2, [2, 2], [1])
0.2 0.6
```
"""
function partial_trace::Matrix, sys_dim::AbstractVector{Int}, dim_2_keep::AbstractVector{Int})
partial_trace::AbstractMatrix, sys_dim::AbstractVector{<:Integer}, dim_2_keep::Integer) = partial_trace(ρ,sys_dim,[dim_2_keep])

function partial_trace::AbstractMatrix, sys_dim::AbstractVector{<:Integer}, dim_2_keep::AbstractVector{<:Integer})
size(ρ, 1) != size(ρ, 2) && throw(ArgumentError("ρ is not a square matrix."))
prod(sys_dim) != size(ρ, 1) && throw(ArgumentError("System dimensions do not multiply to density matrix dimension."))

Expand All @@ -226,7 +228,7 @@ function partial_trace(ρ::Matrix, sys_dim::AbstractVector{Int}, dim_2_keep::Abs
re_dim[dim_2_trace] .= 1
tr_dim = copy(sys_dim)
tr_dim[dim_2_keep] .= 1
res = zeros(ComplexF64, re_dim..., re_dim...)
res = zeros(eltype(ρ), re_dim..., re_dim...)
for I in CartesianIndices(size(res))
for k in CartesianIndices((tr_dim...,))
delta = CartesianIndex(k, k)
Expand Down Expand Up @@ -358,4 +360,4 @@ function haar_unitary(dim)
L = diag(r)
L=L./abs.(L)
q*diagm(0=>L)
end
end

0 comments on commit 50e551e

Please sign in to comment.