From e978bbbe5e52b3312cdb5cae48c19a3cd5f1d23f Mon Sep 17 00:00:00 2001 From: Hiroshi Shinaoka Date: Sat, 22 Jun 2024 13:00:45 +0200 Subject: [PATCH] Implement makesitediagonal and extractdiagonal --- src/tag.jl | 2 +- src/util.jl | 28 ++++++++++++++++++++++++++++ test/transformer_tests.jl | 2 ++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/tag.jl b/src/tag.jl index 770611d..e17a163 100644 --- a/src/tag.jl +++ b/src/tag.jl @@ -50,7 +50,7 @@ function findallsites_by_tag(sites::Vector{Vector{Index{T}}}; tag::String="x", sitesflatten = collect(Iterators.flatten(sites)) for n in 1:maxnsites tag_ = tag * "=$n" - idx = findall(hastags(tag_), sitesflatten) + idx = findall(i -> hastags(i, tag_) && hasplev(i, 0), sitesflatten) if length(idx) == 0 break elseif length(idx) > 1 diff --git a/src/util.jl b/src/util.jl index 0a0f547..bbac3d5 100644 --- a/src/util.jl +++ b/src/util.jl @@ -495,3 +495,31 @@ function makesitediagonal(M::AbstractMPS, tag::String)::MPS return MPS(collect(M_)) end + +""" +Extract diagonal components +""" +function extractdiagonal(M::AbstractMPS, tag::String)::MPS + M_ = deepcopy(MPO(collect(M))) + sites = siteinds(M_) + + target_positions = findallsites_by_tag(siteinds(M_); tag=tag) + + for t in eachindex(target_positions) + i, j = target_positions[t] + M_[i] = _extract_diagonal(M_[i], sites[i][j], sites[i][j]') + end + + return MPS(collect(M_)) +end + +function _extract_diagonal(t, site::Index{T}, site2::Index{T}) where {T<:Number} + dim(site) == dim(site2) || error("Dimension mismatch") + restinds = uniqueinds(inds(t), site, site2) + newdata = zeros(eltype(t), dim.(restinds)..., dim(site)) + olddata = Array(t, restinds..., site, site2) + for i in 1:dim(site) + newdata[.., i] = olddata[.., i, i] + end + return ITensor(newdata, restinds..., site) +end diff --git a/test/transformer_tests.jl b/test/transformer_tests.jl index 272009a..6096f68 100644 --- a/test/transformer_tests.jl +++ b/test/transformer_tests.jl @@ -76,6 +76,8 @@ end f_ref[i] = g_reconst[mod(2^nbit - (i - 1), 2^nbit) + 1] end f_ref[1] *= bc + + @test f_reconst ≈ f_ref end @testset "reverseaxis2" for nbit in 2:3, rev_carrydirec in [true, false]