From f667e4f5e3338e29d9e0f94f37d1d5f509cb8a70 Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Thu, 5 Dec 2024 12:01:36 -0500 Subject: [PATCH] Add coeffs options to add --- src/partitionedmps.jl | 30 ++++++++++++++++++++++++++---- test/partitionedmps_tests.jl | 9 +++++++-- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/partitionedmps.jl b/src/partitionedmps.jl index fe05764..6a00ba0 100644 --- a/src/partitionedmps.jl +++ b/src/partitionedmps.jl @@ -118,22 +118,44 @@ function Base.:+( alg="directsum", cutoff=0.0, maxdim=typemax(Int), + coeffs=(1.0, 1.0), kwargs..., )::PartitionedMPS + result = PartitionedMPS() + return add!(result, a, b; alg, cutoff, maxdim, coeffs, kwargs...) +end + +function add!( + result::PartitionedMPS, + a::PartitionedMPS, + b::PartitionedMPS; + alg="directsum", + cutoff=0.0, + maxdim=typemax(Int), + overwrite=true, + coeffs=(1.0, 1.0), + kwargs..., +)::PartitionedMPS + length(coeffs) == 2 || error("coeffs must be a tuple of length 2") data = SubDomainMPS[] for k in unique(vcat(collect(keys(a)), collect(keys(b)))) # preserve order + if k ∈ keys(result) && !overwrite + continue + end if k ∈ keys(a) && k ∈ keys(b) a[k].projector == b[k].projector || error("Projectors mismatch at $(k)") - push!(data, +(a[k], b[k]; alg, cutoff, maxdim, kwargs...)) + push!( + data, +(coeffs[1] * a[k], coeffs[2] * b[k]; alg, cutoff, maxdim, kwargs...) + ) elseif k ∈ keys(a) - push!(data, a[k]) + push!(data, coeffs[1] * a[k]) elseif k ∈ keys(b) - push!(data, b[k]) + push!(data, coeffs[2] * b[k]) else error("Something went wrong") end end - return PartitionedMPS(data) + return append!(result, data) end function Base.:*(a::PartitionedMPS, b::Number)::PartitionedMPS diff --git a/test/partitionedmps_tests.jl b/test/partitionedmps_tests.jl index 44fd839..67c367c 100644 --- a/test/partitionedmps_tests.jl +++ b/test/partitionedmps_tests.jl @@ -30,10 +30,15 @@ import PartitionedMPSs: PartitionedMPSs, Projector, project, SubDomainMPS, Parti @test length([(k, v) for (k, v) in PartitionedMPS(prjΨ1)]) == 1 Ψreconst = PartitionedMPS(prjΨ1) + PartitionedMPS(prjΨ2) - @test Ψreconst[1] == prjΨ1 - @test Ψreconst[2] == prjΨ2 + @test Ψreconst[1] ≈ prjΨ1 + @test Ψreconst[2] ≈ prjΨ2 @test MPS(Ψreconst) ≈ Ψ @test ITensors.norm(Ψreconst) ≈ ITensors.norm(MPS(Ψreconst)) + + # Summation + coeffs = (1.1, 0.9) + @test MPS(+(PartitionedMPS(prjΨ1), PartitionedMPS(prjΨ2); coeffs=coeffs)) ≈ + coeffs[1] * MPS(prjΨ1) + coeffs[2] * MPS(prjΨ2) end @testset "two blocks (general key)" begin