Skip to content

Commit

Permalink
add truncation functionalities for iTT
Browse files Browse the repository at this point in the history
  • Loading branch information
stecrotti committed Dec 9, 2024
1 parent f3333b7 commit de36384
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 8 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
MPSKit = "bb1c41ca-d63c-52ed-829e-0820dda26502"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TensorCast = "02d47bb6-7ce6-556a-be16-bb1710789e2b"
TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"

[compat]
Expand All @@ -20,8 +22,10 @@ Lazy = "0.15"
LinearAlgebra = "1.8"
LogarithmicNumbers = "1.4"
MKL = "0.6.3, 0.7"
MPSKit = "0.11.5"
Random = "1.8"
StatsBase = "0.33, 0.34"
TensorCast = "0.4"
TensorKit = "0.12.0"
Tullio = "0.3"
julia = "1.8"
7 changes: 5 additions & 2 deletions src/TensorTrains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ using Lazy: @forward
using LinearAlgebra: LinearAlgebra, svd, norm, tr, I, dot, normalize!
using LogarithmicNumbers: Logarithmic
using MKL
using MPSKit: InfiniteMPS, DenseMPO, VUMPS, approximate, dot, add_util_leg, site_type, physicalspace
using Random: AbstractRNG, default_rng
using StatsBase: StatsBase, sample!, sample
using TensorCast: @cast, TensorCast
using TensorKit: TensorMap, , ℝ, id, storagetype
using Tullio: @tullio


export
getindex, iterate, firstindex, lastindex, setindex!, eachindex, length, show,
SVDTrunc, TruncBond, TruncThresh, TruncBondMax, TruncBondThresh, summary_compact,
Expand All @@ -24,7 +25,8 @@ export
# Uniform Tensor Trains
AbstractUniformTensorTrain, UniformTensorTrain, InfiniteUniformTensorTrain,
symmetrized_uniform_tensor_train, periodic_tensor_train,
flat_infinite_uniform_tt, rand_infinite_uniform_tt
flat_infinite_uniform_tt, rand_infinite_uniform_tt,
TruncVUMPS


include("utils.jl")
Expand All @@ -36,5 +38,6 @@ include("periodic_tensor_train.jl")
# Uniform Tensor Trains
include("UniformTensorTrains/uniform_tensor_train.jl")
include("UniformTensorTrains/transfer_operator.jl")
include("UniformTensorTrains/vumps_trunc.jl")

end # end module
5 changes: 0 additions & 5 deletions src/UniformTensorTrains/uniform_tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,6 @@ function orthogonalize_right!(::AbstractUniformTensorTrain; kw...)
error("Not implemented")
end

function compress!(A::AbstractUniformTensorTrain; kw...)
@warn "Compressing a uniform Tensor Train: I'm not doing anyhing (yet)"
return A
end

function _compose(f, ::AbstractUniformTensorTrain, ::AbstractUniformTensorTrain)
error("Not implemented")
end
Expand Down
72 changes: 72 additions & 0 deletions src/UniformTensorTrains/vumps_trunc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
TruncVUMPS{TI, TF} <: SVDTrunc
A type used to perform truncations of an [`InfiniteUniformTensorTrain`](@ref) to a target bond size `d`.
It uses the Variational Uniform Matrix Product States (VUMPS) algorithm from MPSKit.jl.
# FIELDS
- `d`: target bond dimension
- `maxiter = 100`: max number of iterations for the VUMPS algorithm
- `tol = 1e-14`: tolerance for the VUMPS algorithm
```@example
p = rand_infinite_uniform_tt(10, 2, 2)
compress!(p, TruncVUMPS(5))
```
"""
struct TruncVUMPS{TI<:Integer, TF<:Real} <: SVDTrunc
d :: TI
maxiter :: TI
tol :: TF
end
TruncVUMPS(d::Integer; maxiter=100, tol=1e-14) = TruncVUMPS(d, maxiter, tol)

summary(svd_trunc::TruncVUMPS) = "VUMPS truncation to bond size m'="*string(svd_trunc.d)

function truncate_vumps(A::Array{F,3}, d;
init = rand(d, size(A,2), d),
maxiter = 100, kw_vumps...) where {F}
ψ = InfiniteMPS([TensorMap(init, (ℝ^d ^size(A,2)), ℝ^d)])
Q = size(A, 2)
m = size(A, 1)
@assert size(A, 3) == m
t = TensorMap(A,(ℝ^m ^Q), ℝ^m) # the same but as a type digestible by MPSKit.jl
ψ₀ = InfiniteMPS([t])
II = DenseMPO([add_util_leg(id(storagetype(site_type(ψ₀)), physicalspace(ψ₀, i)))
for i in 1:length(ψ₀)])
alg = VUMPS(; maxiter, verbosity=0, kw_vumps...) # variational approximation algorithm
# alg = IDMRG1(; maxiter)
@assert typeof(ψ) == typeof(ψ₀)
ψ_, = approximate(ψ, (II, ψ₀), alg) # do the truncation
@assert typeof(ψ) == typeof(ψ_)

ovl = abs(dot(ψ_, ψ₀))
B = reshape(only(ψ_.AL).data, d, Q, d)
return B, ovl, ψ_
end

function compress!(A::InfiniteUniformTensorTrain; svd_trunc::TruncVUMPS=TruncVUMPS(4),
is_orthogonal::Symbol=:none, init = rand_infinite_uniform_tt(svd_trunc.d, size(A.tensor)[3:end]...))
(; d, maxiter, tol) = svd_trunc
qs = size(A.tensor)[3:end]
B = reshape(A.tensor, size(A.tensor)[1:2]..., prod(qs))
Bperm = permutedims(B, (1,3,2))
# reduce or expand `init` to match bond dimension `svd_trunc.d`
s = size(init.tensor)
init_resized = if s[1] != svd_trunc.d
init_ = InfiniteUniformTensorTrain(zeros(svd_trunc.d, svd_trunc.d, size(A.tensor)[3:end]...))
init_.tensor[1:s[1],1:s[2],fill(:,length(qs))...] = init.tensor
init_
else
init
end
@debug begin
if size(permutedims(_reshape1(init_resized.tensor), (1,3,2))) != size(rand(svd_trunc.d, prod(size(A.tensor)[3:end]), svd_trunc.d))
@show size(permutedims(_reshape1(init_resized.tensor), (1,3,2))) size(rand(svd_trunc.d, prod(size(A.tensor)[3:end]), svd_trunc.d))
end
end
Btruncperm, = truncate_vumps(Bperm, d; maxiter, tol, init = permutedims(_reshape1(init_resized.tensor), (1,3,2)))
Btrunc = permutedims(Btruncperm, (1,3,2))
A.tensor = reshape(Btrunc, size(Btrunc)[1:2]..., qs...)
return A
end
13 changes: 12 additions & 1 deletion test/uniform_tensor_train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
@test_throws ArgumentError (A[3] = rand(rng, 4,4,2,3))
@test_throws "Not implemented" orthogonalize_left!(A)
@test_throws "Not implemented" orthogonalize_right!(A)
@test_warn "Compressing a uniform Tensor Train: I'm not doing anyhing (yet)" compress!(A)
tensor = rand(rng, 4,4,2,3)
B = UniformTensorTrain(tensor, L)
@test_throws "Not implemented" A - B
Expand Down Expand Up @@ -130,4 +129,16 @@ end

r = flat_infinite_uniform_tt(2, 3, 4)
@test dot(r, r) 1
end

@testset "VUMPS truncations" begin
rng = MersenneTwister(0)
A = rand(rng, 10,10,3,4)
p = InfiniteUniformTensorTrain(A)
q = deepcopy(p)
compress!(p; svd_trunc=TruncVUMPS(8))
@test size(p.tensor)[1:2] == (8, 8)
marg = real(only(marginals(q)))
marg_compressed = real(only(marginals(p)))
@test isapprox(marg, marg_compressed, atol=1e-5)
end

0 comments on commit de36384

Please sign in to comment.