Skip to content

Commit

Permalink
update ptrace and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
apkille committed Jul 2, 2024
1 parent 194ebba commit 5997dfb
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 27 deletions.
5 changes: 3 additions & 2 deletions src/QSymbolicsBase/QSymbolicsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ export SymQObj,QObj,
SScaled,SScaledBra,SScaledOperator,SScaledKet,
STensorBra,STensorKet,STensorOperator,
SZeroBra,SZeroKet,SZeroOperator,
SProjector,MixedState,IdentityOp,SInvOperator,
SApplyKet,SApplyBra,SMulOperator,SSuperOpApply,SCommutator,SAnticommutator,SDagger,SBraKet,SOuterKetBra,
MixedState,IdentityOp,
SProjector,SDagger,STrace,SPartialTrace,SInvOperator,
SApplyKet,SApplyBra,SMulOperator,SSuperOpApply,SCommutator,SAnticommutator,SBraKet,SOuterKetBra,
HGate,XGate,YGate,ZGate,CPHASEGate,CNOTGate,
XBasisState,YBasisState,ZBasisState,
NumberOp,CreateOp,DestroyOp,
Expand Down
2 changes: 1 addition & 1 deletion src/QSymbolicsBase/basic_ops_homogeneous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ function Base.:(*)(xs::Symbolic{AbstractOperator}...)
end
end
Base.show(io::IO, x::SMulOperator) = print(io, join(map(string, arguments(x)),""))
basis(x::SMulOperator) = basis(x.terms)
basis(x::SMulOperator) = basis(first(x.terms))

"""Tensor product of quantum objects (kets, operators, or bras)
Expand Down
40 changes: 25 additions & 15 deletions src/QSymbolicsBase/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ Base.isequal(x::STrace, y::STrace) = isequal(x.op, y.op)
"""Partial trace over system i of a composite quantum system
```jldoctest
julia> @op 𝒪 SpinBasis(1//2)SpinBasis(1//2);
julia> @op 𝒪 SpinBasis(1//2)SpinBasis(1//2);
julia> op = ptrace(𝒪, 1)
tr1(𝒪)
Expand Down Expand Up @@ -180,21 +180,15 @@ function basis(x::SPartialTrace)
end
Base.show(io::IO, x::SPartialTrace) = print(io, "tr$(x.sys)($(x.obj))")
function ptrace(x::Symbolic{AbstractOperator}, s)
if isa(basis(x), CompositeBasis)
SPartialTrace(x, s)
ex = isexpr(x) ? qexpand(x) : x
if isa(ex, typeof(x))
if isa(basis(x), CompositeBasis)
SPartialTrace(x, s)
else
throw(ArgumentError("cannot take partial trace of a single quantum system"))
end
else
throw(ArgumentError("cannot take partial trace of a single quantum system"))
end
end
function ptrace(x::STensorOperator, s)
terms = arguments(x)
newterms = []
if isa(basis(terms[s]), CompositeBasis)
SPartial(x, s)
else
sys_op = terms[s]
new_terms = deleteat!(copy(terms), s)
isone(length(new_terms)) ? tr(sys_op)*first(new_terms) : tr(sys_op)*STensorOperator(new_terms)
ptrace(ex, s)
end
end
function ptrace(x::SAddOperator, s)
Expand All @@ -212,6 +206,22 @@ function ptrace(x::SAddOperator, s)
end
(+)(add_terms...)
end
function ptrace(x::STensorOperator, s)
ex = qexpand(x)
if isa(ex, SAddOperator)
ptrace(ex, s)
else
terms = arguments(ex)
newterms = []
if isa(basis(terms[s]), CompositeBasis)
SPartial(ex, s)
else
sys_op = terms[s]
new_terms = deleteat!(copy(terms), s)
isone(length(new_terms)) ? tr(sys_op)*first(new_terms) : tr(sys_op)*STensorOperator(new_terms)
end
end
end

"""Inverse Operator
Expand Down
5 changes: 3 additions & 2 deletions src/QSymbolicsBase/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,10 @@ RULES_EXPAND = [
@rule(+(~~ops) ~o1 => +(map(op -> op ~o1, ~~ops)...)),
@rule(~o1 * +(~~ops) => +(map(op -> ~o1 * op, ~~ops)...)),
@rule(+(~~ops) * ~o1 => +(map(op -> op * ~o1, ~~ops)...)),
@rule(+(~~ops) * ~o1 => +(map(op -> op * ~o1, ~~ops)...)),
@rule((~~ops1::_vecisa(Symbolic{AbstractBra})) * (~~ops2::_vecisa(Symbolic{AbstractKet})) => *(map(*, ~~ops1, ~~ops2)...)),
@rule((~~ops1::_vecisa(Symbolic{AbstractOperator})) * (~~ops2::_vecisa(Symbolic{AbstractOperator})) => (map(*, ~~ops1, ~~ops2)...)),
@rule((~~ops1::_vecisa(Symbolic{AbstractBra})) * (~~ops2::_vecisa(Symbolic{AbstractKet})) => *(map(*, ~~ops1, ~~ops2)...))
@rule(~o1::_isa(Symbolic{AbstractOperator}) * (~~ops) => (map(op -> ~o1 * op, ~~ops)...)),
@rule((~~ops) * ~o1::_isa(Symbolic{AbstractOperator}) => (map(op -> op * ~o1, ~~ops)...)),
]

#
Expand Down
3 changes: 3 additions & 0 deletions test/test_expand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ using Test
@test isequal(qexpand((B+C+D)*A), B*A + C*A + D*A)
@test isequal(qexpand(commutator(A, B) * C), A*B*C - B*A*C)

@test isequal(qexpand(A*(BCD)), (A*B)(A*C)(A*D))
@test isequal(qexpand((BCD)*A), (B*A)(C*A)(D*A))

@test isequal(qexpand((AB)*(CD)), (A*C)(B*D))
@test isequal(qexpand((b₁b₂)*(k₁k₂)), (b₁*k₁)*(b₂*k₂))
end
43 changes: 36 additions & 7 deletions test/test_trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,49 @@ using Test

@bra b₁; @bra b₂;
@ket k₁; @ket k₂;
@op A; @op B; @op C;
@op A; @op B; @op C; @op D; @op E; @op F; @op 𝒪 SpinBasis(1//2)SpinBasis(1//2);

@testset "trace tests" begin
@test isequal(tr(2*A), 2*tr(A))
@test isequal(tr(A+B), tr(A)+tr(B))
@test isequal(tr(k₁*b₁), b₁*k₁)
@test isequal(tr(commutator(A, B)), 0)
@test isequal(tr(()(A, B, C)), tr(A)*tr(B)*tr(C))
@test isequal(tr(ABC), tr(A)*tr(B)*tr(C))
end

exp1 = ABC
exp2 = (k₁*b₁)A + (k₂*b₂)B
exp3 = A(BC + DE)
exp4 = A(BC + DE)*F
@testset "partial trace tests" begin
@test isequal(ptrace(()(A, B, C), 1), tr(A)*(BC))
@test isequal(ptrace(()(A, B, C), 2), tr(B)*(AC))
@test isequal(ptrace(()(A, B, C), 3), tr(C)*(AB))
@test isequal(ptrace((k₁*b₁)A + (k₂*b₂)B, 1), (b₁*k₁)*A + (b₂*k₂)*B)
@test isequal(ptrace((k₁*b₁)A + (k₂*b₂)B, 2), tr(A)*(k₁*b₁) + tr(B)*(k₂*b₂))
@test isequal(ptrace(𝒪, 1), SPartialTrace(𝒪, 1))
@test isequal(QuantumSymbolics.basis(ptrace(𝒪, 1)), SpinBasis(1//2))
@test isequal(ptrace(𝒪, 2), SPartialTrace(𝒪, 2))
@test isequal(QuantumSymbolics.basis(ptrace(𝒪, 2)), SpinBasis(1//2))

@test isequal(ptrace(exp1, 1), tr(A)*(BC))
@test isequal(QuantumSymbolics.basis(ptrace(exp1, 1)), SpinBasis(1//2)SpinBasis(1//2))
@test isequal(ptrace(exp1, 2), tr(B)*(AC))
@test isequal(QuantumSymbolics.basis(ptrace(exp1, 2)), SpinBasis(1//2)SpinBasis(1//2))
@test isequal(ptrace(exp1, 3), tr(C)*(AB))
@test isequal(QuantumSymbolics.basis(ptrace(exp1, 3)), SpinBasis(1//2)SpinBasis(1//2))

@test isequal(ptrace(exp2, 1), (b₁*k₁)*A + (b₂*k₂)*B)
@test isequal(QuantumSymbolics.basis(ptrace(exp2, 1)), SpinBasis(1//2))
@test isequal(ptrace(exp2, 2), tr(A)*(k₁*b₁) + tr(B)*(k₂*b₂))
@test isequal(QuantumSymbolics.basis(ptrace(exp2, 2)), SpinBasis(1//2))

@test isequal(ptrace(exp3, 1), tr(A)*(BC) + tr(A)*(DE))
@test isequal(QuantumSymbolics.basis(ptrace(exp3, 1)), SpinBasis(1//2)SpinBasis(1//2))
@test isequal(ptrace(exp3, 2), tr(B)*(AC) + tr(D)*(AE))
@test isequal(QuantumSymbolics.basis(ptrace(exp3, 2)), SpinBasis(1//2)SpinBasis(1//2))
@test isequal(ptrace(exp3, 3), tr(C)*(AB) + tr(E)*(AD))
@test isequal(QuantumSymbolics.basis(ptrace(exp3, 3)), SpinBasis(1//2)SpinBasis(1//2))

@test isequal(ptrace(exp4, 1), tr(A*F)*((B*F)(C*F)) + tr(A*F)*((D*F)(E*F)))
@test isequal(QuantumSymbolics.basis(ptrace(exp4, 1)), SpinBasis(1//2)SpinBasis(1//2))
@test isequal(ptrace(exp4, 2), tr(B*F)*((A*F)(C*F)) + tr(D*F)*((A*F)(E*F)))
@test isequal(QuantumSymbolics.basis(ptrace(exp4, 2)), SpinBasis(1//2)SpinBasis(1//2))
@test isequal(ptrace(exp4, 3), tr(C*F)*((A*F)(B*F)) + tr(E*F)*((A*F)(D*F)))
@test isequal(QuantumSymbolics.basis(ptrace(exp4, 2)), SpinBasis(1//2)SpinBasis(1//2))
end

0 comments on commit 5997dfb

Please sign in to comment.