diff --git a/src/QSymbolicsBase/QSymbolicsBase.jl b/src/QSymbolicsBase/QSymbolicsBase.jl index 80a553d..82737fa 100644 --- a/src/QSymbolicsBase/QSymbolicsBase.jl +++ b/src/QSymbolicsBase/QSymbolicsBase.jl @@ -35,8 +35,9 @@ export SymQObj,QObj, SScaled,SScaledBra,SScaledOperator,SScaledKet, STensorBra,STensorKet,STensorOperator, SZeroBra,SZeroKet,SZeroOperator, - SProjector,MixedState,IdentityOp,SInvOperator, - SApplyKet,SApplyBra,SMulOperator,SSuperOpApply,SCommutator,SAnticommutator,SDagger,SBraKet,SOuterKetBra, + MixedState,IdentityOp, + SProjector,SDagger,STrace,SPartialTrace,SInvOperator, + SApplyKet,SApplyBra,SMulOperator,SSuperOpApply,SCommutator,SAnticommutator,SBraKet,SOuterKetBra, HGate,XGate,YGate,ZGate,CPHASEGate,CNOTGate, XBasisState,YBasisState,ZBasisState, NumberOp,CreateOp,DestroyOp, diff --git a/src/QSymbolicsBase/basic_ops_homogeneous.jl b/src/QSymbolicsBase/basic_ops_homogeneous.jl index 2714e52..8656b18 100644 --- a/src/QSymbolicsBase/basic_ops_homogeneous.jl +++ b/src/QSymbolicsBase/basic_ops_homogeneous.jl @@ -147,7 +147,7 @@ function Base.:(*)(xs::Symbolic{AbstractOperator}...) end end Base.show(io::IO, x::SMulOperator) = print(io, join(map(string, arguments(x)),"")) -basis(x::SMulOperator) = basis(x.terms) +basis(x::SMulOperator) = basis(first(x.terms)) """Tensor product of quantum objects (kets, operators, or bras) diff --git a/src/QSymbolicsBase/linalg.jl b/src/QSymbolicsBase/linalg.jl index 7d40ff7..4dd6ffc 100644 --- a/src/QSymbolicsBase/linalg.jl +++ b/src/QSymbolicsBase/linalg.jl @@ -129,7 +129,7 @@ Base.isequal(x::STrace, y::STrace) = isequal(x.op, y.op) """Partial trace over system i of a composite quantum system ```jldoctest -julia> @op 𝒪 SpinBasis(1//2) ⊗ SpinBasis(1//2); +julia> @op 𝒪 SpinBasis(1//2)⊗SpinBasis(1//2); julia> op = ptrace(𝒪, 1) tr1(𝒪) @@ -180,21 +180,15 @@ function basis(x::SPartialTrace) end Base.show(io::IO, x::SPartialTrace) = print(io, "tr$(x.sys)($(x.obj))") function ptrace(x::Symbolic{AbstractOperator}, s) - if isa(basis(x), CompositeBasis) - SPartialTrace(x, s) + ex = isexpr(x) ? qexpand(x) : x + if isa(ex, typeof(x)) + if isa(basis(x), CompositeBasis) + SPartialTrace(x, s) + else + throw(ArgumentError("cannot take partial trace of a single quantum system")) + end else - throw(ArgumentError("cannot take partial trace of a single quantum system")) - end -end -function ptrace(x::STensorOperator, s) - terms = arguments(x) - newterms = [] - if isa(basis(terms[s]), CompositeBasis) - SPartial(x, s) - else - sys_op = terms[s] - new_terms = deleteat!(copy(terms), s) - isone(length(new_terms)) ? tr(sys_op)*first(new_terms) : tr(sys_op)*STensorOperator(new_terms) + ptrace(ex, s) end end function ptrace(x::SAddOperator, s) @@ -212,6 +206,22 @@ function ptrace(x::SAddOperator, s) end (+)(add_terms...) end +function ptrace(x::STensorOperator, s) + ex = qexpand(x) + if isa(ex, SAddOperator) + ptrace(ex, s) + else + terms = arguments(ex) + newterms = [] + if isa(basis(terms[s]), CompositeBasis) + SPartial(ex, s) + else + sys_op = terms[s] + new_terms = deleteat!(copy(terms), s) + isone(length(new_terms)) ? tr(sys_op)*first(new_terms) : tr(sys_op)*STensorOperator(new_terms) + end + end +end """Inverse Operator diff --git a/src/QSymbolicsBase/rules.jl b/src/QSymbolicsBase/rules.jl index 38db184..66494ae 100644 --- a/src/QSymbolicsBase/rules.jl +++ b/src/QSymbolicsBase/rules.jl @@ -102,9 +102,10 @@ RULES_EXPAND = [ @rule(+(~~ops) ⊗ ~o1 => +(map(op -> op ⊗ ~o1, ~~ops)...)), @rule(~o1 * +(~~ops) => +(map(op -> ~o1 * op, ~~ops)...)), @rule(+(~~ops) * ~o1 => +(map(op -> op * ~o1, ~~ops)...)), - @rule(+(~~ops) * ~o1 => +(map(op -> op * ~o1, ~~ops)...)), + @rule(⊗(~~ops1::_vecisa(Symbolic{AbstractBra})) * ⊗(~~ops2::_vecisa(Symbolic{AbstractKet})) => *(map(*, ~~ops1, ~~ops2)...)), @rule(⊗(~~ops1::_vecisa(Symbolic{AbstractOperator})) * ⊗(~~ops2::_vecisa(Symbolic{AbstractOperator})) => ⊗(map(*, ~~ops1, ~~ops2)...)), - @rule(⊗(~~ops1::_vecisa(Symbolic{AbstractBra})) * ⊗(~~ops2::_vecisa(Symbolic{AbstractKet})) => *(map(*, ~~ops1, ~~ops2)...)) + @rule(~o1::_isa(Symbolic{AbstractOperator}) * ⊗(~~ops) => ⊗(map(op -> ~o1 * op, ~~ops)...)), + @rule(⊗(~~ops) * ~o1::_isa(Symbolic{AbstractOperator}) => ⊗(map(op -> op * ~o1, ~~ops)...)), ] # diff --git a/test/test_expand.jl b/test/test_expand.jl index fd7fc30..8dab6b3 100644 --- a/test/test_expand.jl +++ b/test/test_expand.jl @@ -26,6 +26,9 @@ using Test @test isequal(qexpand((B+C+D)*A), B*A + C*A + D*A) @test isequal(qexpand(commutator(A, B) * C), A*B*C - B*A*C) + @test isequal(qexpand(A*(B⊗C⊗D)), (A*B)⊗(A*C)⊗(A*D)) + @test isequal(qexpand((B⊗C⊗D)*A), (B*A)⊗(C*A)⊗(D*A)) + @test isequal(qexpand((A⊗B)*(C⊗D)), (A*C)⊗(B*D)) @test isequal(qexpand((b₁⊗b₂)*(k₁⊗k₂)), (b₁*k₁)*(b₂*k₂)) end \ No newline at end of file diff --git a/test/test_trace.jl b/test/test_trace.jl index 1343b89..f4e9f5c 100644 --- a/test/test_trace.jl +++ b/test/test_trace.jl @@ -3,20 +3,49 @@ using Test @bra b₁; @bra b₂; @ket k₁; @ket k₂; -@op A; @op B; @op C; +@op A; @op B; @op C; @op D; @op E; @op F; @op 𝒪 SpinBasis(1//2)⊗SpinBasis(1//2); @testset "trace tests" begin @test isequal(tr(2*A), 2*tr(A)) @test isequal(tr(A+B), tr(A)+tr(B)) @test isequal(tr(k₁*b₁), b₁*k₁) @test isequal(tr(commutator(A, B)), 0) - @test isequal(tr((⊗)(A, B, C)), tr(A)*tr(B)*tr(C)) + @test isequal(tr(A⊗B⊗C), tr(A)*tr(B)*tr(C)) end +exp1 = A⊗B⊗C +exp2 = (k₁*b₁)⊗A + (k₂*b₂)⊗B +exp3 = A⊗(B⊗C + D⊗E) +exp4 = A⊗(B⊗C + D⊗E)*F @testset "partial trace tests" begin - @test isequal(ptrace((⊗)(A, B, C), 1), tr(A)*(B⊗C)) - @test isequal(ptrace((⊗)(A, B, C), 2), tr(B)*(A⊗C)) - @test isequal(ptrace((⊗)(A, B, C), 3), tr(C)*(A⊗B)) - @test isequal(ptrace((k₁*b₁)⊗A + (k₂*b₂)⊗B, 1), (b₁*k₁)*A + (b₂*k₂)*B) - @test isequal(ptrace((k₁*b₁)⊗A + (k₂*b₂)⊗B, 2), tr(A)*(k₁*b₁) + tr(B)*(k₂*b₂)) + @test isequal(ptrace(𝒪, 1), SPartialTrace(𝒪, 1)) + @test isequal(QuantumSymbolics.basis(ptrace(𝒪, 1)), SpinBasis(1//2)) + @test isequal(ptrace(𝒪, 2), SPartialTrace(𝒪, 2)) + @test isequal(QuantumSymbolics.basis(ptrace(𝒪, 2)), SpinBasis(1//2)) + + @test isequal(ptrace(exp1, 1), tr(A)*(B⊗C)) + @test isequal(QuantumSymbolics.basis(ptrace(exp1, 1)), SpinBasis(1//2)⊗SpinBasis(1//2)) + @test isequal(ptrace(exp1, 2), tr(B)*(A⊗C)) + @test isequal(QuantumSymbolics.basis(ptrace(exp1, 2)), SpinBasis(1//2)⊗SpinBasis(1//2)) + @test isequal(ptrace(exp1, 3), tr(C)*(A⊗B)) + @test isequal(QuantumSymbolics.basis(ptrace(exp1, 3)), SpinBasis(1//2)⊗SpinBasis(1//2)) + + @test isequal(ptrace(exp2, 1), (b₁*k₁)*A + (b₂*k₂)*B) + @test isequal(QuantumSymbolics.basis(ptrace(exp2, 1)), SpinBasis(1//2)) + @test isequal(ptrace(exp2, 2), tr(A)*(k₁*b₁) + tr(B)*(k₂*b₂)) + @test isequal(QuantumSymbolics.basis(ptrace(exp2, 2)), SpinBasis(1//2)) + + @test isequal(ptrace(exp3, 1), tr(A)*(B⊗C) + tr(A)*(D⊗E)) + @test isequal(QuantumSymbolics.basis(ptrace(exp3, 1)), SpinBasis(1//2)⊗SpinBasis(1//2)) + @test isequal(ptrace(exp3, 2), tr(B)*(A⊗C) + tr(D)*(A⊗E)) + @test isequal(QuantumSymbolics.basis(ptrace(exp3, 2)), SpinBasis(1//2)⊗SpinBasis(1//2)) + @test isequal(ptrace(exp3, 3), tr(C)*(A⊗B) + tr(E)*(A⊗D)) + @test isequal(QuantumSymbolics.basis(ptrace(exp3, 3)), SpinBasis(1//2)⊗SpinBasis(1//2)) + + @test isequal(ptrace(exp4, 1), tr(A*F)*((B*F)⊗(C*F)) + tr(A*F)*((D*F)⊗(E*F))) + @test isequal(QuantumSymbolics.basis(ptrace(exp4, 1)), SpinBasis(1//2)⊗SpinBasis(1//2)) + @test isequal(ptrace(exp4, 2), tr(B*F)*((A*F)⊗(C*F)) + tr(D*F)*((A*F)⊗(E*F))) + @test isequal(QuantumSymbolics.basis(ptrace(exp4, 2)), SpinBasis(1//2)⊗SpinBasis(1//2)) + @test isequal(ptrace(exp4, 3), tr(C*F)*((A*F)⊗(B*F)) + tr(E*F)*((A*F)⊗(D*F))) + @test isequal(QuantumSymbolics.basis(ptrace(exp4, 2)), SpinBasis(1//2)⊗SpinBasis(1//2)) end