Skip to content

Commit

Permalink
Many fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed Nov 27, 2024
1 parent b930a60 commit 2f08b2f
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ function projcontract(
end

if length(results) == 1
return results[1]
return results
end

res = if length(patchorder) > 0
Expand Down
6 changes: 4 additions & 2 deletions src/partitionedmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ struct PartitionedMPS
data::OrderedDict{Projector,SubDomainMPS}

function PartitionedMPS(data::AbstractVector{SubDomainMPS})
length(data) > 0 || error("Empty data")
sites_all = [siteinds(prjmps) for prjmps in data]
for n in 2:length(data)
Set(sites_all[n]) == Set(sites_all[1]) || error("Sitedims mismatch")
Expand All @@ -32,9 +33,10 @@ end
ITensors.siteinds(obj::PartitionedMPS) = siteindices(obj)

"""
Get the number of sites in the PartitionedMPS
Get the number of the data in the PartitionedMPS.
This is NOT the number of sites in the PartitionedMPS.
"""
Base.length(obj::PartitionedMPS) = length(first(obj.data))
Base.length(obj::PartitionedMPS) = length(obj.data)

"""
Indexing for PartitionedMPS. This is deprecated and will be removed in the future.
Expand Down
4 changes: 3 additions & 1 deletion src/projector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ Projector(singleproj::Pair{Index{T},Int}) where {T} =
Projector(Dict{Index,Int}(singleproj.first => singleproj.second))

function Base.hash(p::Projector, h::UInt)
tmp = hash(collect(Iterators.flatten(((hash(k, h), hash(v, h)) for (k, v) in p.data))))
tmp = hash(
sort(collect(Iterators.flatten(((hash(k, h), hash(v, h)) for (k, v) in p.data))))
)
return Base.hash(tmp, h)
end

Expand Down
7 changes: 6 additions & 1 deletion test/partitionedmps_tests.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
using Test

using ITensors
using ITensorMPS

import PartitionedMPSs: Projector, project, SubDomainMPS, PartitionedMPS

@testset "PartitionedMPS.jl" begin
@testset "partitionedmps.jl" begin
@testset "two blocks" begin
N = 3
sitesx = [Index(2, "x=$n") for n in 1:N]
Expand All @@ -22,6 +23,10 @@ import PartitionedMPSs: Projector, project, SubDomainMPS, PartitionedMPS
@test_throws ErrorException PartitionedMPS([prjΨ, prjΨ1])
@test_throws ErrorException PartitionedMPS([prjΨ1, prjΨ1])

# Iterator and length
@test length(PartitionedMPS(prjΨ1)) == 1
@test length([(k, v) for (k, v) in PartitionedMPS(prjΨ1)]) == 1

Ψreconst = PartitionedMPS(prjΨ1) + PartitionedMPS(prjΨ2)
@test Ψreconst[1] == prjΨ1
@test Ψreconst[2] == prjΨ2
Expand Down

0 comments on commit 2f08b2f

Please sign in to comment.