Skip to content

Commit

Permalink
Add coeffs options to add
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed Dec 5, 2024
1 parent 68df451 commit f667e4f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
30 changes: 26 additions & 4 deletions src/partitionedmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,22 +118,44 @@ function Base.:+(
alg="directsum",
cutoff=0.0,
maxdim=typemax(Int),
coeffs=(1.0, 1.0),
kwargs...,
)::PartitionedMPS
result = PartitionedMPS()
return add!(result, a, b; alg, cutoff, maxdim, coeffs, kwargs...)
end

function add!(
result::PartitionedMPS,
a::PartitionedMPS,
b::PartitionedMPS;
alg="directsum",
cutoff=0.0,
maxdim=typemax(Int),
overwrite=true,
coeffs=(1.0, 1.0),
kwargs...,
)::PartitionedMPS
length(coeffs) == 2 || error("coeffs must be a tuple of length 2")
data = SubDomainMPS[]
for k in unique(vcat(collect(keys(a)), collect(keys(b)))) # preserve order
if k keys(result) && !overwrite
continue
end
if k keys(a) && k keys(b)
a[k].projector == b[k].projector || error("Projectors mismatch at $(k)")
push!(data, +(a[k], b[k]; alg, cutoff, maxdim, kwargs...))
push!(
data, +(coeffs[1] * a[k], coeffs[2] * b[k]; alg, cutoff, maxdim, kwargs...)
)
elseif k keys(a)
push!(data, a[k])
push!(data, coeffs[1] * a[k])
elseif k keys(b)
push!(data, b[k])
push!(data, coeffs[2] * b[k])
else
error("Something went wrong")
end
end
return PartitionedMPS(data)
return append!(result, data)
end

function Base.:*(a::PartitionedMPS, b::Number)::PartitionedMPS
Expand Down
9 changes: 7 additions & 2 deletions test/partitionedmps_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,15 @@ import PartitionedMPSs: PartitionedMPSs, Projector, project, SubDomainMPS, Parti
@test length([(k, v) for (k, v) in PartitionedMPS(prjΨ1)]) == 1

Ψreconst = PartitionedMPS(prjΨ1) + PartitionedMPS(prjΨ2)
@test Ψreconst[1] == prjΨ1
@test Ψreconst[2] == prjΨ2
@test Ψreconst[1] prjΨ1
@test Ψreconst[2] prjΨ2
@test MPS(Ψreconst) Ψ
@test ITensors.norm(Ψreconst) ITensors.norm(MPS(Ψreconst))

# Summation
coeffs = (1.1, 0.9)
@test MPS(+(PartitionedMPS(prjΨ1), PartitionedMPS(prjΨ2); coeffs=coeffs))
coeffs[1] * MPS(prjΨ1) + coeffs[2] * MPS(prjΨ2)
end

@testset "two blocks (general key)" begin
Expand Down

0 comments on commit f667e4f

Please sign in to comment.