From 2e9dfb7864e5b10d431f15e384d6d986ee814ae2 Mon Sep 17 00:00:00 2001 From: Omar Elrefaei Date: Sat, 19 Oct 2024 11:01:42 -0400 Subject: [PATCH] fix Aqua's reported piracies and method ambiguities --- src/QSymbolicsBase/basic_ops_homogeneous.jl | 13 +++++++------ src/QSymbolicsBase/basic_superops.jl | 2 ++ test/test_aqua.jl | 6 ++++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/QSymbolicsBase/basic_ops_homogeneous.jl b/src/QSymbolicsBase/basic_ops_homogeneous.jl index ed41ff8..1263438 100644 --- a/src/QSymbolicsBase/basic_ops_homogeneous.jl +++ b/src/QSymbolicsBase/basic_ops_homogeneous.jl @@ -29,7 +29,7 @@ arguments(x::SScaled) = [x.coeff,x.obj] operation(x::SScaled) = * head(x::SScaled) = :* children(x::SScaled) = [:*,x.coeff,x.obj] -function Base.:(*)(c, x::Symbolic{T}) where {T<:QObj} +function Base.:(*)(c::U, x::Symbolic{T}) where {U<:Union{Number, Symbolic{<:Number}},T<:QObj} if (isa(c, Number) && iszero(c)) || iszero(x) SZero{T}() elseif _isone(c) @@ -40,9 +40,9 @@ function Base.:(*)(c, x::Symbolic{T}) where {T<:QObj} SScaled{T}(c, x) end end -Base.:(*)(x::Symbolic{T}, c) where {T<:QObj} = c*x +Base.:(*)(x::Symbolic{T}, c::Number) where {T<:QObj} = c*x Base.:(*)(x::Symbolic{T}, y::Symbolic{S}) where {T<:QObj,S<:QObj} = throw(ArgumentError("multiplication between $(typeof(x)) and $(typeof(y)) is not defined; maybe you are looking for a tensor product `tensor`")) -Base.:(/)(x::Symbolic{T}, c) where {T<:QObj} = iszero(c) ? throw(DomainError(c,"cannot divide QSymbolics expressions by zero")) : (1/c)*x +Base.:(/)(x::Symbolic{T}, c::Number) where {T<:QObj} = iszero(c) ? throw(DomainError(c,"cannot divide QSymbolics expressions by zero")) : (1/c)*x basis(x::SScaled) = basis(x.obj) const SScaledKet = SScaled{AbstractKet} @@ -94,13 +94,13 @@ arguments(x::SAdd) = x._arguments_precomputed operation(x::SAdd) = + head(x::SAdd) = :+ children(x::SAdd) = [:+; x._arguments_precomputed] -function Base.:(+)(xs::Vararg{Symbolic{T},N}) where {T<:QObj,N} +function Base.:(+)(x::Symbolic{T}, xs::Vararg{Symbolic{T}, N}) where {T<:QObj, N} + xs = (x, xs...) xs = collect(xs) f = first(xs) nonzero_terms = filter!(x->!iszero(x),xs) isempty(nonzero_terms) ? f : SAdd{T}(countmap_flatten(nonzero_terms, SScaled{T})) end -Base.:(+)(xs::Vararg{Symbolic{<:QObj},0}) = 0 # to avoid undefined type parameters issue in the above method basis(x::SAdd) = basis(first(x.dict).first) const SAddBra = SAdd{AbstractBra} @@ -137,7 +137,8 @@ arguments(x::SMulOperator) = x.terms operation(x::SMulOperator) = * head(x::SMulOperator) = :* children(x::SMulOperator) = [:*;x.terms] -function Base.:(*)(xs::Symbolic{AbstractOperator}...) +function Base.:(*)(x::Symbolic{AbstractOperator}, xs::Vararg{Symbolic{AbstractOperator}, N}) where {N} + xs = (x, xs...) zero_ind = findfirst(x->iszero(x), xs) if isnothing(zero_ind) if any(x->!(samebases(basis(x),basis(first(xs)))),xs) diff --git a/src/QSymbolicsBase/basic_superops.jl b/src/QSymbolicsBase/basic_superops.jl index b52792c..2cea4c2 100644 --- a/src/QSymbolicsBase/basic_superops.jl +++ b/src/QSymbolicsBase/basic_superops.jl @@ -29,6 +29,8 @@ kraus(xs::Symbolic{AbstractOperator}...) = KrausRepr(collect(xs)) basis(x::KrausRepr) = basis(first(x.krausops)) Base.:(*)(sop::KrausRepr, op::Symbolic{AbstractOperator}) = (+)((i*op*dagger(i) for i in sop.krausops)...) Base.:(*)(sop::KrausRepr, k::Symbolic{AbstractKet}) = (+)((i*SProjector(k)*dagger(i) for i in sop.krausops)...) +Base.:(*)(sop::KrausRepr, k::SZeroOperator) = SZeroOperator() +Base.:(*)(sop::KrausRepr, k::SZeroKet) = SZeroOperator() Base.show(io::IO, x::KrausRepr) = print(io, "𝒦("*join([symbollabel(i) for i in x.krausops], ",")*")") ## diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 24f6440..61e30e3 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -1,7 +1,9 @@ @testitem "Aqua" tags=[:aqua] begin using Aqua + import QuantumInterface as QI + own_types = [QI.AbstractBra, QI.AbstractKet, QI.AbstractSuperOperator, QI.AbstractOperator] Aqua.test_all(QuantumSymbolics, - ambiguities=(;broken=true), - piracies=(;broken=true), + ambiguities=(), + piracies=(;treat_as_own=own_types), ) end